1# -*- encoding: utf-8 -*- 2# Samba traffic replay and learning 3# 4# Copyright (C) Catalyst IT Ltd. 2017 5# 6# This program is free software; you can redistribute it and/or modify 7# it under the terms of the GNU General Public License as published by 8# the Free Software Foundation; either version 3 of the License, or 9# (at your option) any later version. 10# 11# This program is distributed in the hope that it will be useful, 12# but WITHOUT ANY WARRANTY; without even the implied warranty of 13# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14# GNU General Public License for more details. 15# 16# You should have received a copy of the GNU General Public License 17# along with this program. If not, see <http://www.gnu.org/licenses/>. 18# 19from __future__ import print_function, division 20 21import time 22import os 23import random 24import json 25import math 26import sys 27import signal 28from errno import ECHILD, ESRCH 29 30from collections import OrderedDict, Counter, defaultdict, namedtuple 31from dns.resolver import query as dns_query 32 33from samba.emulate import traffic_packets 34from samba.samdb import SamDB 35import ldb 36from ldb import LdbError 37from samba.dcerpc import ClientConnection 38from samba.dcerpc import security, drsuapi, lsa 39from samba.dcerpc import netlogon 40from samba.dcerpc.netlogon import netr_Authenticator 41from samba.dcerpc import srvsvc 42from samba.dcerpc import samr 43from samba.drs_utils import drs_DsBind 44import traceback 45from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS 46from samba.auth import system_session 47from samba.dsdb import ( 48 UF_NORMAL_ACCOUNT, 49 UF_SERVER_TRUST_ACCOUNT, 50 UF_TRUSTED_FOR_DELEGATION, 51 UF_WORKSTATION_TRUST_ACCOUNT 52) 53from samba.dcerpc.misc import SEC_CHAN_BDC 54from samba import gensec 55from samba import sd_utils 56from samba.compat import get_string 57from samba.logger import get_samba_logger 58import bisect 59 60CURRENT_MODEL_VERSION = 2 # save as this 61REQUIRED_MODEL_VERSION = 2 # load accepts this or greater 62SLEEP_OVERHEAD = 3e-4 63 64# we don't use None, because it complicates [de]serialisation 65NON_PACKET = '-' 66 67CLIENT_CLUES = { 68 ('dns', '0'): 1.0, # query 69 ('smb', '0x72'): 1.0, # Negotiate protocol 70 ('ldap', '0'): 1.0, # bind 71 ('ldap', '3'): 1.0, # searchRequest 72 ('ldap', '2'): 1.0, # unbindRequest 73 ('cldap', '3'): 1.0, 74 ('dcerpc', '11'): 1.0, # bind 75 ('dcerpc', '14'): 1.0, # Alter_context 76 ('nbns', '0'): 1.0, # query 77} 78 79SERVER_CLUES = { 80 ('dns', '1'): 1.0, # response 81 ('ldap', '1'): 1.0, # bind response 82 ('ldap', '4'): 1.0, # search result 83 ('ldap', '5'): 1.0, # search done 84 ('cldap', '5'): 1.0, 85 ('dcerpc', '12'): 1.0, # bind_ack 86 ('dcerpc', '13'): 1.0, # bind_nak 87 ('dcerpc', '15'): 1.0, # Alter_context response 88} 89 90SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"} 91 92WAIT_SCALE = 10.0 93WAIT_THRESHOLD = (1.0 / WAIT_SCALE) 94NO_WAIT_LOG_TIME_RANGE = (-10, -3) 95 96# DEBUG_LEVEL can be changed by scripts with -d 97DEBUG_LEVEL = 0 98 99LOGGER = get_samba_logger(name=__name__) 100 101 102def debug(level, msg, *args): 103 """Print a formatted debug message to standard error. 104 105 106 :param level: The debug level, message will be printed if it is <= the 107 currently set debug level. The debug level can be set with 108 the -d option. 109 :param msg: The message to be logged, can contain C-Style format 110 specifiers 111 :param args: The parameters required by the format specifiers 112 """ 113 if level <= DEBUG_LEVEL: 114 if not args: 115 print(msg, file=sys.stderr) 116 else: 117 print(msg % tuple(args), file=sys.stderr) 118 119 120def debug_lineno(*args): 121 """ Print an unformatted log message to stderr, contaning the line number 122 """ 123 tb = traceback.extract_stack(limit=2) 124 print((" %s:" "\033[01;33m" 125 "%s " "\033[00m" % (tb[0][2], tb[0][1])), end=' ', 126 file=sys.stderr) 127 for a in args: 128 print(a, file=sys.stderr) 129 print(file=sys.stderr) 130 sys.stderr.flush() 131 132 133def random_colour_print(seeds): 134 """Return a function that prints a coloured line to stderr. The colour 135 of the line depends on a sort of hash of the integer arguments.""" 136 if seeds: 137 s = 214 138 for x in seeds: 139 s += 17 140 s *= x 141 s %= 214 142 prefix = "\033[38;5;%dm" % (18 + s) 143 144 def p(*args): 145 if DEBUG_LEVEL > 0: 146 for a in args: 147 print("%s%s\033[00m" % (prefix, a), file=sys.stderr) 148 else: 149 def p(*args): 150 if DEBUG_LEVEL > 0: 151 for a in args: 152 print(a, file=sys.stderr) 153 154 return p 155 156 157class FakePacketError(Exception): 158 pass 159 160 161class Packet(object): 162 """Details of a network packet""" 163 __slots__ = ('timestamp', 164 'ip_protocol', 165 'stream_number', 166 'src', 167 'dest', 168 'protocol', 169 'opcode', 170 'desc', 171 'extra', 172 'endpoints') 173 def __init__(self, timestamp, ip_protocol, stream_number, src, dest, 174 protocol, opcode, desc, extra): 175 self.timestamp = timestamp 176 self.ip_protocol = ip_protocol 177 self.stream_number = stream_number 178 self.src = src 179 self.dest = dest 180 self.protocol = protocol 181 self.opcode = opcode 182 self.desc = desc 183 self.extra = extra 184 if self.src < self.dest: 185 self.endpoints = (self.src, self.dest) 186 else: 187 self.endpoints = (self.dest, self.src) 188 189 @classmethod 190 def from_line(cls, line): 191 fields = line.rstrip('\n').split('\t') 192 (timestamp, 193 ip_protocol, 194 stream_number, 195 src, 196 dest, 197 protocol, 198 opcode, 199 desc) = fields[:8] 200 extra = fields[8:] 201 202 timestamp = float(timestamp) 203 src = int(src) 204 dest = int(dest) 205 206 return cls(timestamp, ip_protocol, stream_number, src, dest, 207 protocol, opcode, desc, extra) 208 209 def as_summary(self, time_offset=0.0): 210 """Format the packet as a traffic_summary line. 211 """ 212 extra = '\t'.join(self.extra) 213 t = self.timestamp + time_offset 214 return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' % 215 (t, 216 self.ip_protocol, 217 self.stream_number or '', 218 self.src, 219 self.dest, 220 self.protocol, 221 self.opcode, 222 self.desc, 223 extra)) 224 225 def __str__(self): 226 return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" % 227 (self.timestamp, self.src, self.dest, self.ip_protocol or '-', 228 self.stream_number, self.protocol, self.opcode, self.desc, 229 ('«' + ' '.join(self.extra) + '»' if self.extra else ''))) 230 231 def __repr__(self): 232 return "<Packet @%s>" % self 233 234 def copy(self): 235 return self.__class__(self.timestamp, 236 self.ip_protocol, 237 self.stream_number, 238 self.src, 239 self.dest, 240 self.protocol, 241 self.opcode, 242 self.desc, 243 self.extra) 244 245 def as_packet_type(self): 246 t = '%s:%s' % (self.protocol, self.opcode) 247 return t 248 249 def client_score(self): 250 """A positive number means we think it is a client; a negative number 251 means we think it is a server. Zero means no idea. range: -1 to 1. 252 """ 253 key = (self.protocol, self.opcode) 254 if key in CLIENT_CLUES: 255 return CLIENT_CLUES[key] 256 if key in SERVER_CLUES: 257 return -SERVER_CLUES[key] 258 return 0.0 259 260 def play(self, conversation, context): 261 """Send the packet over the network, if required. 262 263 Some packets are ignored, i.e. for protocols not handled, 264 server response messages, or messages that are generated by the 265 protocol layer associated with other packets. 266 """ 267 fn_name = 'packet_%s_%s' % (self.protocol, self.opcode) 268 try: 269 fn = getattr(traffic_packets, fn_name) 270 271 except AttributeError as e: 272 print("Conversation(%s) Missing handler %s" % 273 (conversation.conversation_id, fn_name), 274 file=sys.stderr) 275 return 276 277 # Don't display a message for kerberos packets, they're not directly 278 # generated they're used to indicate kerberos should be used 279 if self.protocol != "kerberos": 280 debug(2, "Conversation(%s) Calling handler %s" % 281 (conversation.conversation_id, fn_name)) 282 283 start = time.time() 284 try: 285 if fn(self, conversation, context): 286 # Only collect timing data for functions that generate 287 # network traffic, or fail 288 end = time.time() 289 duration = end - start 290 print("%f\t%s\t%s\t%s\t%f\tTrue\t" % 291 (end, conversation.conversation_id, self.protocol, 292 self.opcode, duration)) 293 except Exception as e: 294 end = time.time() 295 duration = end - start 296 print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" % 297 (end, conversation.conversation_id, self.protocol, 298 self.opcode, duration, e)) 299 300 def __cmp__(self, other): 301 return self.timestamp - other.timestamp 302 303 def is_really_a_packet(self, missing_packet_stats=None): 304 return is_a_real_packet(self.protocol, self.opcode) 305 306 307def is_a_real_packet(protocol, opcode): 308 """Is the packet one that can be ignored? 309 310 If so removing it will have no effect on the replay 311 """ 312 if protocol in SKIPPED_PROTOCOLS: 313 # Ignore any packets for the protocols we're not interested in. 314 return False 315 if protocol == "ldap" and opcode == '': 316 # skip ldap continuation packets 317 return False 318 319 fn_name = 'packet_%s_%s' % (protocol, opcode) 320 fn = getattr(traffic_packets, fn_name, None) 321 if fn is None: 322 LOGGER.debug("missing packet %s" % fn_name, file=sys.stderr) 323 return False 324 if fn is traffic_packets.null_packet: 325 return False 326 return True 327 328 329def is_a_traffic_generating_packet(protocol, opcode): 330 """Return true if a packet generates traffic in its own right. Some of 331 these will generate traffic in certain contexts (e.g. ldap unbind 332 after a bind) but not if the conversation consists only of these packets. 333 """ 334 if protocol == 'wait': 335 return False 336 337 if (protocol, opcode) in ( 338 ('kerberos', ''), 339 ('ldap', '2'), 340 ('dcerpc', '15'), 341 ('dcerpc', '16')): 342 return False 343 344 return is_a_real_packet(protocol, opcode) 345 346 347class ReplayContext(object): 348 """State/Context for a conversation between an simulated client and a 349 server. Some of the context is shared amongst all conversations 350 and should be generated before the fork, while other context is 351 specific to a particular conversation and should be generated 352 *after* the fork, in generate_process_local_config(). 353 """ 354 def __init__(self, 355 server=None, 356 lp=None, 357 creds=None, 358 total_conversations=None, 359 badpassword_frequency=None, 360 prefer_kerberos=None, 361 tempdir=None, 362 statsdir=None, 363 ou=None, 364 base_dn=None, 365 domain=os.environ.get("DOMAIN"), 366 domain_sid=None, 367 instance_id=None): 368 self.server = server 369 self.netlogon_connection = None 370 self.creds = creds 371 self.lp = lp 372 if prefer_kerberos: 373 self.kerberos_state = MUST_USE_KERBEROS 374 else: 375 self.kerberos_state = DONT_USE_KERBEROS 376 self.ou = ou 377 self.base_dn = base_dn 378 self.domain = domain 379 self.statsdir = statsdir 380 self.global_tempdir = tempdir 381 self.domain_sid = domain_sid 382 self.realm = lp.get('realm') 383 self.instance_id = instance_id 384 385 # Bad password attempt controls 386 self.badpassword_frequency = badpassword_frequency 387 self.last_lsarpc_bad = False 388 self.last_lsarpc_named_bad = False 389 self.last_simple_bind_bad = False 390 self.last_bind_bad = False 391 self.last_srvsvc_bad = False 392 self.last_drsuapi_bad = False 393 self.last_netlogon_bad = False 394 self.last_samlogon_bad = False 395 self.total_conversations = total_conversations 396 self.generate_ldap_search_tables() 397 398 def generate_ldap_search_tables(self): 399 session = system_session() 400 401 db = SamDB(url="ldap://%s" % self.server, 402 session_info=session, 403 credentials=self.creds, 404 lp=self.lp) 405 406 res = db.search(db.domain_dn(), 407 scope=ldb.SCOPE_SUBTREE, 408 controls=["paged_results:1:1000"], 409 attrs=['dn']) 410 411 # find a list of dns for each pattern 412 # e.g. CN,CN,CN,DC,DC 413 dn_map = {} 414 attribute_clue_map = { 415 'invocationId': [] 416 } 417 418 for r in res: 419 dn = str(r.dn) 420 pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper() 421 dns = dn_map.setdefault(pattern, []) 422 dns.append(dn) 423 if dn.startswith('CN=NTDS Settings,'): 424 attribute_clue_map['invocationId'].append(dn) 425 426 # extend the map in case we are working with a different 427 # number of DC components. 428 # for k, v in self.dn_map.items(): 429 # print >>sys.stderr, k, len(v) 430 431 for k in list(dn_map.keys()): 432 if k[-3:] != ',DC': 433 continue 434 p = k[:-3] 435 while p[-3:] == ',DC': 436 p = p[:-3] 437 for i in range(5): 438 p += ',DC' 439 if p != k and p in dn_map: 440 print('dn_map collison %s %s' % (k, p), 441 file=sys.stderr) 442 continue 443 dn_map[p] = dn_map[k] 444 445 self.dn_map = dn_map 446 self.attribute_clue_map = attribute_clue_map 447 448 # pre-populate DN-based search filters (it's simplest to generate them 449 # once, when the test starts). These are used by guess_search_filter() 450 # to avoid full-scans 451 self.search_filters = {} 452 453 # lookup all the GPO DNs 454 res = db.search(db.domain_dn(), scope=ldb.SCOPE_SUBTREE, attrs=['dn'], 455 expression='(objectclass=groupPolicyContainer)') 456 gpos_by_dn = "".join("(distinguishedName={0})".format(msg['dn']) for msg in res) 457 458 # a search for the 'gPCFileSysPath' attribute is probably a GPO search 459 # (as per the MS-GPOL spec) which searches for GPOs by DN 460 self.search_filters['gPCFileSysPath'] = "(|{0})".format(gpos_by_dn) 461 462 # likewise, a search for gpLink is probably the Domain SOM search part 463 # of the MS-GPOL, in which case it's looking up a few OUs by DN 464 ou_str = "" 465 for ou in ["Domain Controllers,", "traffic_replay,", ""]: 466 ou_str += "(distinguishedName={0}{1})".format(ou, db.domain_dn()) 467 self.search_filters['gpLink'] = "(|{0})".format(ou_str) 468 469 # The CEP Web Service can query the AD DC to get pKICertificateTemplate 470 # objects (as per MS-WCCE) 471 self.search_filters['pKIExtendedKeyUsage'] = \ 472 '(objectCategory=pKICertificateTemplate)' 473 474 # assume that anything querying the usnChanged is some kind of 475 # synchronization tool, e.g. AD Change Detection Connector 476 res = db.search('', scope=ldb.SCOPE_BASE, attrs=['highestCommittedUSN']) 477 self.search_filters['usnChanged'] = \ 478 '(usnChanged>={0})'.format(res[0]['highestCommittedUSN']) 479 480 # The traffic_learner script doesn't preserve the LDAP search filter, and 481 # having no filter can result in a full DB scan. This is costly for a large 482 # DB, and not necessarily representative of real world traffic. As there 483 # several standard LDAP queries that get used by AD tools, we can apply 484 # some logic and guess what the search filter might have been originally. 485 def guess_search_filter(self, attrs, dn_sig, dn): 486 487 # there are some standard spec-based searches that query fairly unique 488 # attributes. Check if the search is likely one of these 489 for key in self.search_filters.keys(): 490 if key in attrs: 491 return self.search_filters[key] 492 493 # if it's the top-level domain, assume we're looking up a single user, 494 # e.g. like powershell Get-ADUser or a similar tool 495 if dn_sig == 'DC,DC': 496 random_user_id = random.random() % self.total_conversations 497 account_name = user_name(self.instance_id, random_user_id) 498 return '(&(sAMAccountName=%s)(objectClass=user))' % account_name 499 500 # otherwise just return everything in the sub-tree 501 return '(objectClass=*)' 502 503 def generate_process_local_config(self, account, conversation): 504 self.ldap_connections = [] 505 self.dcerpc_connections = [] 506 self.lsarpc_connections = [] 507 self.lsarpc_connections_named = [] 508 self.drsuapi_connections = [] 509 self.srvsvc_connections = [] 510 self.samr_contexts = [] 511 self.netbios_name = account.netbios_name 512 self.machinepass = account.machinepass 513 self.username = account.username 514 self.userpass = account.userpass 515 516 self.tempdir = mk_masked_dir(self.global_tempdir, 517 'conversation-%d' % 518 conversation.conversation_id) 519 520 self.lp.set("private dir", self.tempdir) 521 self.lp.set("lock dir", self.tempdir) 522 self.lp.set("state directory", self.tempdir) 523 self.lp.set("tls verify peer", "no_check") 524 525 self.remoteAddress = "/root/ncalrpc_as_system" 526 self.samlogon_dn = ("cn=%s,%s" % 527 (self.netbios_name, self.ou)) 528 self.user_dn = ("cn=%s,%s" % 529 (self.username, self.ou)) 530 531 self.generate_machine_creds() 532 self.generate_user_creds() 533 534 def with_random_bad_credentials(self, f, good, bad, failed_last_time): 535 """Execute the supplied logon function, randomly choosing the 536 bad credentials. 537 538 Based on the frequency in badpassword_frequency randomly perform the 539 function with the supplied bad credentials. 540 If run with bad credentials, the function is re-run with the good 541 credentials. 542 failed_last_time is used to prevent consecutive bad credential 543 attempts. So the over all bad credential frequency will be lower 544 than that requested, but not significantly. 545 """ 546 if not failed_last_time: 547 if (self.badpassword_frequency and 548 random.random() < self.badpassword_frequency): 549 try: 550 f(bad) 551 except Exception: 552 # Ignore any exceptions as the operation may fail 553 # as it's being performed with bad credentials 554 pass 555 failed_last_time = True 556 else: 557 failed_last_time = False 558 559 result = f(good) 560 return (result, failed_last_time) 561 562 def generate_user_creds(self): 563 """Generate the conversation specific user Credentials. 564 565 Each Conversation has an associated user account used to simulate 566 any non Administrative user traffic. 567 568 Generates user credentials with good and bad passwords and ldap 569 simple bind credentials with good and bad passwords. 570 """ 571 self.user_creds = Credentials() 572 self.user_creds.guess(self.lp) 573 self.user_creds.set_workstation(self.netbios_name) 574 self.user_creds.set_password(self.userpass) 575 self.user_creds.set_username(self.username) 576 self.user_creds.set_domain(self.domain) 577 self.user_creds.set_kerberos_state(self.kerberos_state) 578 579 self.user_creds_bad = Credentials() 580 self.user_creds_bad.guess(self.lp) 581 self.user_creds_bad.set_workstation(self.netbios_name) 582 self.user_creds_bad.set_password(self.userpass[:-4]) 583 self.user_creds_bad.set_username(self.username) 584 self.user_creds_bad.set_kerberos_state(self.kerberos_state) 585 586 # Credentials for ldap simple bind. 587 self.simple_bind_creds = Credentials() 588 self.simple_bind_creds.guess(self.lp) 589 self.simple_bind_creds.set_workstation(self.netbios_name) 590 self.simple_bind_creds.set_password(self.userpass) 591 self.simple_bind_creds.set_username(self.username) 592 self.simple_bind_creds.set_gensec_features( 593 self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL) 594 self.simple_bind_creds.set_kerberos_state(self.kerberos_state) 595 self.simple_bind_creds.set_bind_dn(self.user_dn) 596 597 self.simple_bind_creds_bad = Credentials() 598 self.simple_bind_creds_bad.guess(self.lp) 599 self.simple_bind_creds_bad.set_workstation(self.netbios_name) 600 self.simple_bind_creds_bad.set_password(self.userpass[:-4]) 601 self.simple_bind_creds_bad.set_username(self.username) 602 self.simple_bind_creds_bad.set_gensec_features( 603 self.simple_bind_creds_bad.get_gensec_features() | 604 gensec.FEATURE_SEAL) 605 self.simple_bind_creds_bad.set_kerberos_state(self.kerberos_state) 606 self.simple_bind_creds_bad.set_bind_dn(self.user_dn) 607 608 def generate_machine_creds(self): 609 """Generate the conversation specific machine Credentials. 610 611 Each Conversation has an associated machine account. 612 613 Generates machine credentials with good and bad passwords. 614 """ 615 616 self.machine_creds = Credentials() 617 self.machine_creds.guess(self.lp) 618 self.machine_creds.set_workstation(self.netbios_name) 619 self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC) 620 self.machine_creds.set_password(self.machinepass) 621 self.machine_creds.set_username(self.netbios_name + "$") 622 self.machine_creds.set_domain(self.domain) 623 self.machine_creds.set_kerberos_state(self.kerberos_state) 624 625 self.machine_creds_bad = Credentials() 626 self.machine_creds_bad.guess(self.lp) 627 self.machine_creds_bad.set_workstation(self.netbios_name) 628 self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC) 629 self.machine_creds_bad.set_password(self.machinepass[:-4]) 630 self.machine_creds_bad.set_username(self.netbios_name + "$") 631 self.machine_creds_bad.set_kerberos_state(self.kerberos_state) 632 633 def get_matching_dn(self, pattern, attributes=None): 634 # If the pattern is an empty string, we assume ROOTDSE, 635 # Otherwise we try adding or removing DC suffixes, then 636 # shorter leading patterns until we hit one. 637 # e.g if there is no CN,CN,CN,CN,DC,DC 638 # we first try CN,CN,CN,CN,DC 639 # and CN,CN,CN,CN,DC,DC,DC 640 # then change to CN,CN,CN,DC,DC 641 # and as last resort we use the base_dn 642 attr_clue = self.attribute_clue_map.get(attributes) 643 if attr_clue: 644 return random.choice(attr_clue) 645 646 pattern = pattern.upper() 647 while pattern: 648 if pattern in self.dn_map: 649 return random.choice(self.dn_map[pattern]) 650 # chop one off the front and try it all again. 651 pattern = pattern[3:] 652 653 return self.base_dn 654 655 def get_dcerpc_connection(self, new=False): 656 guid = '12345678-1234-abcd-ef00-01234567cffb' # RPC_NETLOGON UUID 657 if self.dcerpc_connections and not new: 658 return self.dcerpc_connections[-1] 659 c = ClientConnection("ncacn_ip_tcp:%s" % self.server, 660 (guid, 1), self.lp) 661 self.dcerpc_connections.append(c) 662 return c 663 664 def get_srvsvc_connection(self, new=False): 665 if self.srvsvc_connections and not new: 666 return self.srvsvc_connections[-1] 667 668 def connect(creds): 669 return srvsvc.srvsvc("ncacn_np:%s" % (self.server), 670 self.lp, 671 creds) 672 673 (c, self.last_srvsvc_bad) = \ 674 self.with_random_bad_credentials(connect, 675 self.user_creds, 676 self.user_creds_bad, 677 self.last_srvsvc_bad) 678 679 self.srvsvc_connections.append(c) 680 return c 681 682 def get_lsarpc_connection(self, new=False): 683 if self.lsarpc_connections and not new: 684 return self.lsarpc_connections[-1] 685 686 def connect(creds): 687 binding_options = 'schannel,seal,sign' 688 return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" % 689 (self.server, binding_options), 690 self.lp, 691 creds) 692 693 (c, self.last_lsarpc_bad) = \ 694 self.with_random_bad_credentials(connect, 695 self.machine_creds, 696 self.machine_creds_bad, 697 self.last_lsarpc_bad) 698 699 self.lsarpc_connections.append(c) 700 return c 701 702 def get_lsarpc_named_pipe_connection(self, new=False): 703 if self.lsarpc_connections_named and not new: 704 return self.lsarpc_connections_named[-1] 705 706 def connect(creds): 707 return lsa.lsarpc("ncacn_np:%s" % (self.server), 708 self.lp, 709 creds) 710 711 (c, self.last_lsarpc_named_bad) = \ 712 self.with_random_bad_credentials(connect, 713 self.machine_creds, 714 self.machine_creds_bad, 715 self.last_lsarpc_named_bad) 716 717 self.lsarpc_connections_named.append(c) 718 return c 719 720 def get_drsuapi_connection_pair(self, new=False, unbind=False): 721 """get a (drs, drs_handle) tuple""" 722 if self.drsuapi_connections and not new: 723 c = self.drsuapi_connections[-1] 724 return c 725 726 def connect(creds): 727 binding_options = 'seal' 728 binding_string = "ncacn_ip_tcp:%s[%s]" %\ 729 (self.server, binding_options) 730 return drsuapi.drsuapi(binding_string, self.lp, creds) 731 732 (drs, self.last_drsuapi_bad) = \ 733 self.with_random_bad_credentials(connect, 734 self.user_creds, 735 self.user_creds_bad, 736 self.last_drsuapi_bad) 737 738 (drs_handle, supported_extensions) = drs_DsBind(drs) 739 c = (drs, drs_handle) 740 self.drsuapi_connections.append(c) 741 return c 742 743 def get_ldap_connection(self, new=False, simple=False): 744 if self.ldap_connections and not new: 745 return self.ldap_connections[-1] 746 747 def simple_bind(creds): 748 """ 749 To run simple bind against Windows, we need to run 750 following commands in PowerShell: 751 752 Install-windowsfeature ADCS-Cert-Authority 753 Install-AdcsCertificationAuthority -CAType EnterpriseRootCA 754 Restart-Computer 755 756 """ 757 return SamDB('ldaps://%s' % self.server, 758 credentials=creds, 759 lp=self.lp) 760 761 def sasl_bind(creds): 762 return SamDB('ldap://%s' % self.server, 763 credentials=creds, 764 lp=self.lp) 765 if simple: 766 (samdb, self.last_simple_bind_bad) = \ 767 self.with_random_bad_credentials(simple_bind, 768 self.simple_bind_creds, 769 self.simple_bind_creds_bad, 770 self.last_simple_bind_bad) 771 else: 772 (samdb, self.last_bind_bad) = \ 773 self.with_random_bad_credentials(sasl_bind, 774 self.user_creds, 775 self.user_creds_bad, 776 self.last_bind_bad) 777 778 self.ldap_connections.append(samdb) 779 return samdb 780 781 def get_samr_context(self, new=False): 782 if not self.samr_contexts or new: 783 self.samr_contexts.append( 784 SamrContext(self.server, lp=self.lp, creds=self.creds)) 785 return self.samr_contexts[-1] 786 787 def get_netlogon_connection(self): 788 789 if self.netlogon_connection: 790 return self.netlogon_connection 791 792 def connect(creds): 793 return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" % 794 (self.server), 795 self.lp, 796 creds) 797 (c, self.last_netlogon_bad) = \ 798 self.with_random_bad_credentials(connect, 799 self.machine_creds, 800 self.machine_creds_bad, 801 self.last_netlogon_bad) 802 self.netlogon_connection = c 803 return c 804 805 def guess_a_dns_lookup(self): 806 return (self.realm, 'A') 807 808 def get_authenticator(self): 809 auth = self.machine_creds.new_client_authenticator() 810 current = netr_Authenticator() 811 current.cred.data = [x if isinstance(x, int) else ord(x) 812 for x in auth["credential"]] 813 current.timestamp = auth["timestamp"] 814 815 subsequent = netr_Authenticator() 816 return (current, subsequent) 817 818 def write_stats(self, filename, **kwargs): 819 """Write arbitrary key/value pairs to a file in our stats directory in 820 order for them to be picked up later by another process working out 821 statistics.""" 822 filename = os.path.join(self.statsdir, filename) 823 f = open(filename, 'w') 824 for k, v in kwargs.items(): 825 print("%s: %s" % (k, v), file=f) 826 f.close() 827 828 829class SamrContext(object): 830 """State/Context associated with a samr connection. 831 """ 832 def __init__(self, server, lp=None, creds=None): 833 self.connection = None 834 self.handle = None 835 self.domain_handle = None 836 self.domain_sid = None 837 self.group_handle = None 838 self.user_handle = None 839 self.rids = None 840 self.server = server 841 self.lp = lp 842 self.creds = creds 843 844 def get_connection(self): 845 if not self.connection: 846 self.connection = samr.samr( 847 "ncacn_ip_tcp:%s[seal]" % (self.server), 848 lp_ctx=self.lp, 849 credentials=self.creds) 850 851 return self.connection 852 853 def get_handle(self): 854 if not self.handle: 855 c = self.get_connection() 856 self.handle = c.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED) 857 return self.handle 858 859 860class Conversation(object): 861 """Details of a converation between a simulated client and a server.""" 862 def __init__(self, start_time=None, endpoints=None, seq=(), 863 conversation_id=None): 864 self.start_time = start_time 865 self.endpoints = endpoints 866 self.packets = [] 867 self.msg = random_colour_print(endpoints) 868 self.client_balance = 0.0 869 self.conversation_id = conversation_id 870 for p in seq: 871 self.add_short_packet(*p) 872 873 def __cmp__(self, other): 874 if self.start_time is None: 875 if other.start_time is None: 876 return 0 877 return -1 878 if other.start_time is None: 879 return 1 880 return self.start_time - other.start_time 881 882 def add_packet(self, packet): 883 """Add a packet object to this conversation, making a local copy with 884 a conversation-relative timestamp.""" 885 p = packet.copy() 886 887 if self.start_time is None: 888 self.start_time = p.timestamp 889 890 if self.endpoints is None: 891 self.endpoints = p.endpoints 892 893 if p.endpoints != self.endpoints: 894 raise FakePacketError("Conversation endpoints %s don't match" 895 "packet endpoints %s" % 896 (self.endpoints, p.endpoints)) 897 898 p.timestamp -= self.start_time 899 900 if p.src == p.endpoints[0]: 901 self.client_balance -= p.client_score() 902 else: 903 self.client_balance += p.client_score() 904 905 if p.is_really_a_packet(): 906 self.packets.append(p) 907 908 def add_short_packet(self, timestamp, protocol, opcode, extra, 909 client=True, skip_unused_packets=True): 910 """Create a packet from a timestamp, and 'protocol:opcode' pair, and a 911 (possibly empty) list of extra data. If client is True, assume 912 this packet is from the client to the server. 913 """ 914 if skip_unused_packets and not is_a_real_packet(protocol, opcode): 915 return 916 917 src, dest = self.guess_client_server() 918 if not client: 919 src, dest = dest, src 920 key = (protocol, opcode) 921 desc = OP_DESCRIPTIONS.get(key, '') 922 ip_protocol = IP_PROTOCOLS.get(protocol, '06') 923 packet = Packet(timestamp - self.start_time, ip_protocol, 924 '', src, dest, 925 protocol, opcode, desc, extra) 926 # XXX we're assuming the timestamp is already adjusted for 927 # this conversation? 928 # XXX should we adjust client balance for guessed packets? 929 if packet.src == packet.endpoints[0]: 930 self.client_balance -= packet.client_score() 931 else: 932 self.client_balance += packet.client_score() 933 if packet.is_really_a_packet(): 934 self.packets.append(packet) 935 936 def __str__(self): 937 return ("<Conversation %s %s starting %.3f %d packets>" % 938 (self.conversation_id, self.endpoints, self.start_time, 939 len(self.packets))) 940 941 __repr__ = __str__ 942 943 def __iter__(self): 944 return iter(self.packets) 945 946 def __len__(self): 947 return len(self.packets) 948 949 def get_duration(self): 950 if len(self.packets) < 2: 951 return 0 952 return self.packets[-1].timestamp - self.packets[0].timestamp 953 954 def replay_as_summary_lines(self): 955 return [p.as_summary(self.start_time) for p in self.packets] 956 957 def replay_with_delay(self, start, context=None, account=None): 958 """Replay the conversation at the right time. 959 (We're already in a fork).""" 960 # first we sleep until the first packet 961 t = self.start_time 962 now = time.time() - start 963 gap = t - now 964 sleep_time = gap - SLEEP_OVERHEAD 965 if sleep_time > 0: 966 time.sleep(sleep_time) 967 968 miss = (time.time() - start) - t 969 self.msg("starting %s [miss %.3f]" % (self, miss)) 970 971 max_gap = 0.0 972 max_sleep_miss = 0.0 973 # packet times are relative to conversation start 974 p_start = time.time() 975 for p in self.packets: 976 now = time.time() - p_start 977 gap = now - p.timestamp 978 if gap > max_gap: 979 max_gap = gap 980 if gap < 0: 981 sleep_time = -gap - SLEEP_OVERHEAD 982 if sleep_time > 0: 983 time.sleep(sleep_time) 984 t = time.time() - p_start 985 if t - p.timestamp > max_sleep_miss: 986 max_sleep_miss = t - p.timestamp 987 988 p.play(self, context) 989 990 return max_gap, miss, max_sleep_miss 991 992 def guess_client_server(self, server_clue=None): 993 """Have a go at deciding who is the server and who is the client. 994 returns (client, server) 995 """ 996 a, b = self.endpoints 997 998 if self.client_balance < 0: 999 return (a, b) 1000 1001 # in the absense of a clue, we will fall through to assuming 1002 # the lowest number is the server (which is usually true). 1003 1004 if self.client_balance == 0 and server_clue == b: 1005 return (a, b) 1006 1007 return (b, a) 1008 1009 def forget_packets_outside_window(self, s, e): 1010 """Prune any packets outside the timne window we're interested in 1011 1012 :param s: start of the window 1013 :param e: end of the window 1014 """ 1015 self.packets = [p for p in self.packets if s <= p.timestamp <= e] 1016 self.start_time = self.packets[0].timestamp if self.packets else None 1017 1018 def renormalise_times(self, start_time): 1019 """Adjust the packet start times relative to the new start time.""" 1020 for p in self.packets: 1021 p.timestamp -= start_time 1022 1023 if self.start_time is not None: 1024 self.start_time -= start_time 1025 1026 1027class DnsHammer(Conversation): 1028 """A lightweight conversation that generates a lot of dns:0 packets on 1029 the fly""" 1030 1031 def __init__(self, dns_rate, duration, query_file=None): 1032 n = int(dns_rate * duration) 1033 self.times = [random.uniform(0, duration) for i in range(n)] 1034 self.times.sort() 1035 self.rate = dns_rate 1036 self.duration = duration 1037 self.start_time = 0 1038 self.query_choices = self._get_query_choices(query_file=query_file) 1039 1040 def __str__(self): 1041 return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" % 1042 (len(self.times), self.duration, self.rate)) 1043 1044 def _get_query_choices(self, query_file=None): 1045 """ 1046 Read dns query choices from a file, or return default 1047 1048 rname may contain format string like `{realm}` 1049 realm can be fetched from context.realm 1050 """ 1051 1052 if query_file: 1053 with open(query_file, 'r') as f: 1054 text = f.read() 1055 choices = [] 1056 for line in text.splitlines(): 1057 line = line.strip() 1058 if line and not line.startswith('#'): 1059 args = line.split(',') 1060 assert len(args) == 4 1061 choices.append(args) 1062 return choices 1063 else: 1064 return [ 1065 (0, '{realm}', 'A', 'yes'), 1066 (1, '{realm}', 'NS', 'yes'), 1067 (2, '*.{realm}', 'A', 'no'), 1068 (3, '*.{realm}', 'NS', 'no'), 1069 (10, '_msdcs.{realm}', 'A', 'yes'), 1070 (11, '_msdcs.{realm}', 'NS', 'yes'), 1071 (20, 'nx.realm.com', 'A', 'no'), 1072 (21, 'nx.realm.com', 'NS', 'no'), 1073 (22, '*.nx.realm.com', 'A', 'no'), 1074 (23, '*.nx.realm.com', 'NS', 'no'), 1075 ] 1076 1077 def replay(self, context=None): 1078 assert context 1079 assert context.realm 1080 start = time.time() 1081 for t in self.times: 1082 now = time.time() - start 1083 gap = t - now 1084 sleep_time = gap - SLEEP_OVERHEAD 1085 if sleep_time > 0: 1086 time.sleep(sleep_time) 1087 1088 opcode, rname, rtype, exist = random.choice(self.query_choices) 1089 rname = rname.format(realm=context.realm) 1090 success = True 1091 packet_start = time.time() 1092 try: 1093 answers = dns_query(rname, rtype) 1094 if exist == 'yes' and not len(answers): 1095 # expect answers but didn't get, fail 1096 success = False 1097 except Exception: 1098 success = False 1099 finally: 1100 end = time.time() 1101 duration = end - packet_start 1102 print("%f\tDNS\tdns\t%s\t%f\t%s\t" % (end, opcode, duration, success)) 1103 1104 1105def ingest_summaries(files, dns_mode='count'): 1106 """Load a summary traffic summary file and generated Converations from it. 1107 """ 1108 1109 dns_counts = defaultdict(int) 1110 packets = [] 1111 for f in files: 1112 if isinstance(f, str): 1113 f = open(f) 1114 print("Ingesting %s" % (f.name,), file=sys.stderr) 1115 for line in f: 1116 p = Packet.from_line(line) 1117 if p.protocol == 'dns' and dns_mode != 'include': 1118 dns_counts[p.opcode] += 1 1119 else: 1120 packets.append(p) 1121 1122 f.close() 1123 1124 if not packets: 1125 return [], 0 1126 1127 start_time = min(p.timestamp for p in packets) 1128 last_packet = max(p.timestamp for p in packets) 1129 1130 print("gathering packets into conversations", file=sys.stderr) 1131 conversations = OrderedDict() 1132 for i, p in enumerate(packets): 1133 p.timestamp -= start_time 1134 c = conversations.get(p.endpoints) 1135 if c is None: 1136 c = Conversation(conversation_id=(i + 2)) 1137 conversations[p.endpoints] = c 1138 c.add_packet(p) 1139 1140 # We only care about conversations with actual traffic, so we 1141 # filter out conversations with nothing to say. We do that here, 1142 # rather than earlier, because those empty packets contain useful 1143 # hints as to which end of the conversation was the client. 1144 conversation_list = [] 1145 for c in conversations.values(): 1146 if len(c) != 0: 1147 conversation_list.append(c) 1148 1149 # This is obviously not correct, as many conversations will appear 1150 # to start roughly simultaneously at the beginning of the snapshot. 1151 # To which we say: oh well, so be it. 1152 duration = float(last_packet - start_time) 1153 mean_interval = len(conversations) / duration 1154 1155 return conversation_list, mean_interval, duration, dns_counts 1156 1157 1158def guess_server_address(conversations): 1159 # we guess the most common address. 1160 addresses = Counter() 1161 for c in conversations: 1162 addresses.update(c.endpoints) 1163 if addresses: 1164 return addresses.most_common(1)[0] 1165 1166 1167def stringify_keys(x): 1168 y = {} 1169 for k, v in x.items(): 1170 k2 = '\t'.join(k) 1171 y[k2] = v 1172 return y 1173 1174 1175def unstringify_keys(x): 1176 y = {} 1177 for k, v in x.items(): 1178 t = tuple(str(k).split('\t')) 1179 y[t] = v 1180 return y 1181 1182 1183class TrafficModel(object): 1184 def __init__(self, n=3): 1185 self.ngrams = {} 1186 self.query_details = {} 1187 self.n = n 1188 self.dns_opcounts = defaultdict(int) 1189 self.cumulative_duration = 0.0 1190 self.packet_rate = [0, 1] 1191 1192 def learn(self, conversations, dns_opcounts={}): 1193 prev = 0.0 1194 cum_duration = 0.0 1195 key = (NON_PACKET,) * (self.n - 1) 1196 1197 server = guess_server_address(conversations) 1198 1199 for k, v in dns_opcounts.items(): 1200 self.dns_opcounts[k] += v 1201 1202 if len(conversations) > 1: 1203 first = conversations[0].start_time 1204 total = 0 1205 last = first + 0.1 1206 for c in conversations: 1207 total += len(c) 1208 last = max(last, c.packets[-1].timestamp) 1209 1210 self.packet_rate[0] = total 1211 self.packet_rate[1] = last - first 1212 1213 for c in conversations: 1214 client, server = c.guess_client_server(server) 1215 cum_duration += c.get_duration() 1216 key = (NON_PACKET,) * (self.n - 1) 1217 for p in c: 1218 if p.src != client: 1219 continue 1220 1221 elapsed = p.timestamp - prev 1222 prev = p.timestamp 1223 if elapsed > WAIT_THRESHOLD: 1224 # add the wait as an extra state 1225 wait = 'wait:%d' % (math.log(max(1.0, 1226 elapsed * WAIT_SCALE))) 1227 self.ngrams.setdefault(key, []).append(wait) 1228 key = key[1:] + (wait,) 1229 1230 short_p = p.as_packet_type() 1231 self.query_details.setdefault(short_p, 1232 []).append(tuple(p.extra)) 1233 self.ngrams.setdefault(key, []).append(short_p) 1234 key = key[1:] + (short_p,) 1235 1236 self.cumulative_duration += cum_duration 1237 # add in the end 1238 self.ngrams.setdefault(key, []).append(NON_PACKET) 1239 1240 def save(self, f): 1241 ngrams = {} 1242 for k, v in self.ngrams.items(): 1243 k = '\t'.join(k) 1244 ngrams[k] = dict(Counter(v)) 1245 1246 query_details = {} 1247 for k, v in self.query_details.items(): 1248 query_details[k] = dict(Counter('\t'.join(x) if x else '-' 1249 for x in v)) 1250 1251 d = { 1252 'ngrams': ngrams, 1253 'query_details': query_details, 1254 'cumulative_duration': self.cumulative_duration, 1255 'packet_rate': self.packet_rate, 1256 'version': CURRENT_MODEL_VERSION 1257 } 1258 d['dns'] = self.dns_opcounts 1259 1260 if isinstance(f, str): 1261 f = open(f, 'w') 1262 1263 json.dump(d, f, indent=2) 1264 1265 def load(self, f): 1266 if isinstance(f, str): 1267 f = open(f) 1268 1269 d = json.load(f) 1270 1271 try: 1272 version = d["version"] 1273 if version < REQUIRED_MODEL_VERSION: 1274 raise ValueError("the model file is version %d; " 1275 "version %d is required" % 1276 (version, REQUIRED_MODEL_VERSION)) 1277 except KeyError: 1278 raise ValueError("the model file lacks a version number; " 1279 "version %d is required" % 1280 (REQUIRED_MODEL_VERSION)) 1281 1282 for k, v in d['ngrams'].items(): 1283 k = tuple(str(k).split('\t')) 1284 values = self.ngrams.setdefault(k, []) 1285 for p, count in v.items(): 1286 values.extend([str(p)] * count) 1287 values.sort() 1288 1289 for k, v in d['query_details'].items(): 1290 values = self.query_details.setdefault(str(k), []) 1291 for p, count in v.items(): 1292 if p == '-': 1293 values.extend([()] * count) 1294 else: 1295 values.extend([tuple(str(p).split('\t'))] * count) 1296 values.sort() 1297 1298 if 'dns' in d: 1299 for k, v in d['dns'].items(): 1300 self.dns_opcounts[k] += v 1301 1302 self.cumulative_duration = d['cumulative_duration'] 1303 self.packet_rate = d['packet_rate'] 1304 1305 def construct_conversation_sequence(self, timestamp=0.0, 1306 hard_stop=None, 1307 replay_speed=1, 1308 ignore_before=0, 1309 persistence=0): 1310 """Construct an individual conversation packet sequence from the 1311 model. 1312 """ 1313 c = [] 1314 key = (NON_PACKET,) * (self.n - 1) 1315 if ignore_before is None: 1316 ignore_before = timestamp - 1 1317 1318 while True: 1319 p = random.choice(self.ngrams.get(key, (NON_PACKET,))) 1320 if p == NON_PACKET: 1321 if timestamp < ignore_before: 1322 break 1323 if random.random() > persistence: 1324 print("ending after %s (persistence %.1f)" % (key, persistence), 1325 file=sys.stderr) 1326 break 1327 1328 p = 'wait:%d' % random.randrange(5, 12) 1329 print("trying %s instead of end" % p, file=sys.stderr) 1330 1331 if p in self.query_details: 1332 extra = random.choice(self.query_details[p]) 1333 else: 1334 extra = [] 1335 1336 protocol, opcode = p.split(':', 1) 1337 if protocol == 'wait': 1338 log_wait_time = int(opcode) + random.random() 1339 wait = math.exp(log_wait_time) / (WAIT_SCALE * replay_speed) 1340 timestamp += wait 1341 else: 1342 log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE) 1343 wait = math.exp(log_wait) / replay_speed 1344 timestamp += wait 1345 if hard_stop is not None and timestamp > hard_stop: 1346 break 1347 if timestamp >= ignore_before: 1348 c.append((timestamp, protocol, opcode, extra)) 1349 1350 key = key[1:] + (p,) 1351 if key[-2][:5] == 'wait:' and key[-1][:5] == 'wait:': 1352 # two waits in a row can only be caused by "persistence" 1353 # tricks, and will not result in any packets being found. 1354 # Instead we pretend this is a fresh start. 1355 key = (NON_PACKET,) * (self.n - 1) 1356 1357 return c 1358 1359 def scale_to_packet_rate(self, scale): 1360 rate_n, rate_t = self.packet_rate 1361 return scale * rate_n / rate_t 1362 1363 def packet_rate_to_scale(self, pps): 1364 rate_n, rate_t = self.packet_rate 1365 return pps * rate_t / rate_n 1366 1367 def generate_conversation_sequences(self, packet_rate, duration, replay_speed=1, 1368 persistence=0): 1369 """Generate a list of conversation descriptions from the model.""" 1370 1371 # We run the simulation for ten times as long as our desired 1372 # duration, and take the section at the end. 1373 lead_in = 9 * duration 1374 target_packets = int(packet_rate * duration) 1375 conversations = [] 1376 n_packets = 0 1377 1378 while n_packets < target_packets: 1379 start = random.uniform(-lead_in, duration) 1380 c = self.construct_conversation_sequence(start, 1381 hard_stop=duration, 1382 replay_speed=replay_speed, 1383 ignore_before=0, 1384 persistence=persistence) 1385 # will these "packets" generate actual traffic? 1386 # some (e.g. ldap unbind) will not generate anything 1387 # if the previous packets are not there, and if the 1388 # conversation only has those it wastes a process doing nothing. 1389 for timestamp, protocol, opcode, extra in c: 1390 if is_a_traffic_generating_packet(protocol, opcode): 1391 break 1392 else: 1393 continue 1394 1395 conversations.append(c) 1396 n_packets += len(c) 1397 1398 scale = self.packet_rate_to_scale(packet_rate) 1399 print(("we have %d packets (target %d) in %d conversations at %.1f/s " 1400 "(scale %f)" % (n_packets, target_packets, len(conversations), 1401 packet_rate, scale)), 1402 file=sys.stderr) 1403 conversations.sort() # sorts by first element == start time 1404 return conversations 1405 1406 1407def seq_to_conversations(seq, server=1, client=2): 1408 conversations = [] 1409 for s in seq: 1410 if s: 1411 c = Conversation(s[0][0], (server, client), s) 1412 client += 1 1413 conversations.append(c) 1414 return conversations 1415 1416 1417IP_PROTOCOLS = { 1418 'dns': '11', 1419 'rpc_netlogon': '06', 1420 'kerberos': '06', # ratio 16248:258 1421 'smb': '06', 1422 'smb2': '06', 1423 'ldap': '06', 1424 'cldap': '11', 1425 'lsarpc': '06', 1426 'samr': '06', 1427 'dcerpc': '06', 1428 'epm': '06', 1429 'drsuapi': '06', 1430 'browser': '11', 1431 'smb_netlogon': '11', 1432 'srvsvc': '06', 1433 'nbns': '11', 1434} 1435 1436OP_DESCRIPTIONS = { 1437 ('browser', '0x01'): 'Host Announcement (0x01)', 1438 ('browser', '0x02'): 'Request Announcement (0x02)', 1439 ('browser', '0x08'): 'Browser Election Request (0x08)', 1440 ('browser', '0x09'): 'Get Backup List Request (0x09)', 1441 ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)', 1442 ('browser', '0x0f'): 'Local Master Announcement (0x0f)', 1443 ('cldap', '3'): 'searchRequest', 1444 ('cldap', '5'): 'searchResDone', 1445 ('dcerpc', '0'): 'Request', 1446 ('dcerpc', '11'): 'Bind', 1447 ('dcerpc', '12'): 'Bind_ack', 1448 ('dcerpc', '13'): 'Bind_nak', 1449 ('dcerpc', '14'): 'Alter_context', 1450 ('dcerpc', '15'): 'Alter_context_resp', 1451 ('dcerpc', '16'): 'AUTH3', 1452 ('dcerpc', '2'): 'Response', 1453 ('dns', '0'): 'query', 1454 ('dns', '1'): 'response', 1455 ('drsuapi', '0'): 'DsBind', 1456 ('drsuapi', '12'): 'DsCrackNames', 1457 ('drsuapi', '13'): 'DsWriteAccountSpn', 1458 ('drsuapi', '1'): 'DsUnbind', 1459 ('drsuapi', '2'): 'DsReplicaSync', 1460 ('drsuapi', '3'): 'DsGetNCChanges', 1461 ('drsuapi', '4'): 'DsReplicaUpdateRefs', 1462 ('epm', '3'): 'Map', 1463 ('kerberos', ''): '', 1464 ('ldap', '0'): 'bindRequest', 1465 ('ldap', '1'): 'bindResponse', 1466 ('ldap', '2'): 'unbindRequest', 1467 ('ldap', '3'): 'searchRequest', 1468 ('ldap', '4'): 'searchResEntry', 1469 ('ldap', '5'): 'searchResDone', 1470 ('ldap', ''): '*** Unknown ***', 1471 ('lsarpc', '14'): 'lsa_LookupNames', 1472 ('lsarpc', '15'): 'lsa_LookupSids', 1473 ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid', 1474 ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo', 1475 ('lsarpc', '6'): 'lsa_OpenPolicy', 1476 ('lsarpc', '76'): 'lsa_LookupSids3', 1477 ('lsarpc', '77'): 'lsa_LookupNames4', 1478 ('nbns', '0'): 'query', 1479 ('nbns', '1'): 'response', 1480 ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1', 1481 ('rpc_netlogon', '26'): 'NetrServerAuthenticate3', 1482 ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo', 1483 ('rpc_netlogon', '30'): 'NetrServerPasswordSet2', 1484 ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx', 1485 ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts', 1486 ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags', 1487 ('rpc_netlogon', '4'): 'NetrServerReqChallenge', 1488 ('samr', '0',): 'Connect', 1489 ('samr', '16'): 'GetAliasMembership', 1490 ('samr', '17'): 'LookupNames', 1491 ('samr', '18'): 'LookupRids', 1492 ('samr', '19'): 'OpenGroup', 1493 ('samr', '1'): 'Close', 1494 ('samr', '25'): 'QueryGroupMember', 1495 ('samr', '34'): 'OpenUser', 1496 ('samr', '36'): 'QueryUserInfo', 1497 ('samr', '39'): 'GetGroupsForUser', 1498 ('samr', '3'): 'QuerySecurity', 1499 ('samr', '5'): 'LookupDomain', 1500 ('samr', '64'): 'Connect5', 1501 ('samr', '6'): 'EnumDomains', 1502 ('samr', '7'): 'OpenDomain', 1503 ('samr', '8'): 'QueryDomainInfo', 1504 ('smb', '0x04'): 'Close (0x04)', 1505 ('smb', '0x24'): 'Locking AndX (0x24)', 1506 ('smb', '0x2e'): 'Read AndX (0x2e)', 1507 ('smb', '0x32'): 'Trans2 (0x32)', 1508 ('smb', '0x71'): 'Tree Disconnect (0x71)', 1509 ('smb', '0x72'): 'Negotiate Protocol (0x72)', 1510 ('smb', '0x73'): 'Session Setup AndX (0x73)', 1511 ('smb', '0x74'): 'Logoff AndX (0x74)', 1512 ('smb', '0x75'): 'Tree Connect AndX (0x75)', 1513 ('smb', '0xa2'): 'NT Create AndX (0xa2)', 1514 ('smb2', '0'): 'NegotiateProtocol', 1515 ('smb2', '11'): 'Ioctl', 1516 ('smb2', '14'): 'Find', 1517 ('smb2', '16'): 'GetInfo', 1518 ('smb2', '18'): 'Break', 1519 ('smb2', '1'): 'SessionSetup', 1520 ('smb2', '2'): 'SessionLogoff', 1521 ('smb2', '3'): 'TreeConnect', 1522 ('smb2', '4'): 'TreeDisconnect', 1523 ('smb2', '5'): 'Create', 1524 ('smb2', '6'): 'Close', 1525 ('smb2', '8'): 'Read', 1526 ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)', 1527 ('smb_netlogon', '0x17'): ('SAM Active Directory Response - ' 1528 'user unknown (0x17)'), 1529 ('srvsvc', '16'): 'NetShareGetInfo', 1530 ('srvsvc', '21'): 'NetSrvGetInfo', 1531} 1532 1533 1534def expand_short_packet(p, timestamp, src, dest, extra): 1535 protocol, opcode = p.split(':', 1) 1536 desc = OP_DESCRIPTIONS.get((protocol, opcode), '') 1537 ip_protocol = IP_PROTOCOLS.get(protocol, '06') 1538 1539 line = [timestamp, ip_protocol, '', src, dest, protocol, opcode, desc] 1540 line.extend(extra) 1541 return '\t'.join(line) 1542 1543 1544def flushing_signal_handler(signal, frame): 1545 """Signal handler closes standard out and error. 1546 1547 Triggered by a sigterm, ensures that the log messages are flushed 1548 to disk and not lost. 1549 """ 1550 sys.stderr.close() 1551 sys.stdout.close() 1552 os._exit(0) 1553 1554 1555def replay_seq_in_fork(cs, start, context, account, client_id, server_id=1): 1556 """Fork a new process and replay the conversation sequence.""" 1557 # We will need to reseed the random number generator or all the 1558 # clients will end up using the same sequence of random 1559 # numbers. random.randint() is mixed in so the initial seed will 1560 # have an effect here. 1561 seed = client_id * 1000 + random.randint(0, 999) 1562 1563 # flush our buffers so messages won't be written by both sides 1564 sys.stdout.flush() 1565 sys.stderr.flush() 1566 pid = os.fork() 1567 if pid != 0: 1568 return pid 1569 1570 # we must never return, or we'll end up running parts of the 1571 # parent's clean-up code. So we work in a try...finally, and 1572 # try to print any exceptions. 1573 try: 1574 random.seed(seed) 1575 endpoints = (server_id, client_id) 1576 status = 0 1577 t = cs[0][0] 1578 c = Conversation(t, endpoints, seq=cs, conversation_id=client_id) 1579 signal.signal(signal.SIGTERM, flushing_signal_handler) 1580 1581 context.generate_process_local_config(account, c) 1582 sys.stdin.close() 1583 os.close(0) 1584 filename = os.path.join(context.statsdir, 'stats-conversation-%d' % 1585 c.conversation_id) 1586 f = open(filename, 'w') 1587 try: 1588 sys.stdout.close() 1589 os.close(1) 1590 except IOError as e: 1591 LOGGER.info("stdout closing failed with %s" % e) 1592 pass 1593 1594 sys.stdout = f 1595 now = time.time() - start 1596 gap = t - now 1597 sleep_time = gap - SLEEP_OVERHEAD 1598 if sleep_time > 0: 1599 time.sleep(sleep_time) 1600 1601 max_lag, start_lag, max_sleep_miss = c.replay_with_delay(start=start, 1602 context=context) 1603 print("Maximum lag: %f" % max_lag) 1604 print("Start lag: %f" % start_lag) 1605 print("Max sleep miss: %f" % max_sleep_miss) 1606 1607 except Exception: 1608 status = 1 1609 print(("EXCEPTION in child PID %d, conversation %s" % (os.getpid(), c)), 1610 file=sys.stderr) 1611 traceback.print_exc(sys.stderr) 1612 sys.stderr.flush() 1613 finally: 1614 sys.stderr.close() 1615 sys.stdout.close() 1616 os._exit(status) 1617 1618 1619def dnshammer_in_fork(dns_rate, duration, context, query_file=None): 1620 sys.stdout.flush() 1621 sys.stderr.flush() 1622 pid = os.fork() 1623 if pid != 0: 1624 return pid 1625 1626 sys.stdin.close() 1627 os.close(0) 1628 1629 try: 1630 sys.stdout.close() 1631 os.close(1) 1632 except IOError as e: 1633 LOGGER.warn("stdout closing failed with %s" % e) 1634 pass 1635 filename = os.path.join(context.statsdir, 'stats-dns') 1636 sys.stdout = open(filename, 'w') 1637 1638 try: 1639 status = 0 1640 signal.signal(signal.SIGTERM, flushing_signal_handler) 1641 hammer = DnsHammer(dns_rate, duration, query_file=query_file) 1642 hammer.replay(context=context) 1643 except Exception: 1644 status = 1 1645 print(("EXCEPTION in child PID %d, the DNS hammer" % (os.getpid())), 1646 file=sys.stderr) 1647 traceback.print_exc(sys.stderr) 1648 finally: 1649 sys.stderr.close() 1650 sys.stdout.close() 1651 os._exit(status) 1652 1653 1654def replay(conversation_seq, 1655 host=None, 1656 creds=None, 1657 lp=None, 1658 accounts=None, 1659 dns_rate=0, 1660 dns_query_file=None, 1661 duration=None, 1662 latency_timeout=1.0, 1663 stop_on_any_error=False, 1664 **kwargs): 1665 1666 context = ReplayContext(server=host, 1667 creds=creds, 1668 lp=lp, 1669 total_conversations=len(conversation_seq), 1670 **kwargs) 1671 1672 if len(accounts) < len(conversation_seq): 1673 raise ValueError(("we have %d accounts but %d conversations" % 1674 (len(accounts), len(conversation_seq)))) 1675 1676 # Set the process group so that the calling scripts are not killed 1677 # when the forked child processes are killed. 1678 os.setpgrp() 1679 1680 # we delay the start by a bit to allow all the forks to get up and 1681 # running. 1682 delay = len(conversation_seq) * 0.02 1683 start = time.time() + delay 1684 1685 if duration is None: 1686 # end slightly after the last packet of the last conversation 1687 # to start. Conversations other than the last could still be 1688 # going, but we don't care. 1689 duration = conversation_seq[-1][-1][0] + latency_timeout 1690 1691 print("We will start in %.1f seconds" % delay, 1692 file=sys.stderr) 1693 print("We will stop after %.1f seconds" % (duration + delay), 1694 file=sys.stderr) 1695 print("runtime %.1f seconds" % duration, 1696 file=sys.stderr) 1697 1698 # give one second grace for packets to finish before killing begins 1699 end = start + duration + 1.0 1700 1701 LOGGER.info("Replaying traffic for %u conversations over %d seconds" 1702 % (len(conversation_seq), duration)) 1703 1704 context.write_stats('intentions', 1705 Planned_conversations=len(conversation_seq), 1706 Planned_packets=sum(len(x) for x in conversation_seq)) 1707 1708 children = {} 1709 try: 1710 if dns_rate: 1711 pid = dnshammer_in_fork(dns_rate, duration, context, 1712 query_file=dns_query_file) 1713 children[pid] = 1 1714 1715 for i, cs in enumerate(conversation_seq): 1716 account = accounts[i] 1717 client_id = i + 2 1718 pid = replay_seq_in_fork(cs, start, context, account, client_id) 1719 children[pid] = client_id 1720 1721 # HERE, we are past all the forks 1722 t = time.time() 1723 print("all forks done in %.1f seconds, waiting %.1f" % 1724 (t - start + delay, t - start), 1725 file=sys.stderr) 1726 1727 while time.time() < end and children: 1728 time.sleep(0.003) 1729 try: 1730 pid, status = os.waitpid(-1, os.WNOHANG) 1731 except OSError as e: 1732 if e.errno != ECHILD: # no child processes 1733 raise 1734 break 1735 if pid: 1736 c = children.pop(pid, None) 1737 if DEBUG_LEVEL > 0: 1738 print(("process %d finished conversation %d;" 1739 " %d to go" % 1740 (pid, c, len(children))), file=sys.stderr) 1741 if stop_on_any_error and status != 0: 1742 break 1743 1744 except Exception: 1745 print("EXCEPTION in parent", file=sys.stderr) 1746 traceback.print_exc() 1747 finally: 1748 context.write_stats('unfinished', 1749 Unfinished_conversations=len(children)) 1750 1751 for s in (15, 15, 9): 1752 print(("killing %d children with -%d" % 1753 (len(children), s)), file=sys.stderr) 1754 for pid in children: 1755 try: 1756 os.kill(pid, s) 1757 except OSError as e: 1758 if e.errno != ESRCH: # don't fail if it has already died 1759 raise 1760 time.sleep(0.5) 1761 end = time.time() + 1 1762 while children: 1763 try: 1764 pid, status = os.waitpid(-1, os.WNOHANG) 1765 except OSError as e: 1766 if e.errno != ECHILD: 1767 raise 1768 if pid != 0: 1769 c = children.pop(pid, None) 1770 if c is None: 1771 print("children is %s, no pid found" % children) 1772 sys.stderr.flush() 1773 sys.stdout.flush() 1774 os._exit(1) 1775 print(("kill -%d %d KILLED conversation; " 1776 "%d to go" % 1777 (s, pid, len(children))), 1778 file=sys.stderr) 1779 if time.time() >= end: 1780 break 1781 1782 if not children: 1783 break 1784 time.sleep(1) 1785 1786 if children: 1787 print("%d children are missing" % len(children), 1788 file=sys.stderr) 1789 1790 # there may be stragglers that were forked just as ^C was hit 1791 # and don't appear in the list of children. We can get them 1792 # with killpg, but that will also kill us, so this is^H^H would be 1793 # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT), 1794 # so as not to have to fuss around writing signal handlers. 1795 try: 1796 os.killpg(0, 2) 1797 except KeyboardInterrupt: 1798 print("ignoring fake ^C", file=sys.stderr) 1799 1800 1801def openLdb(host, creds, lp): 1802 session = system_session() 1803 ldb = SamDB(url="ldap://%s" % host, 1804 session_info=session, 1805 options=['modules:paged_searches'], 1806 credentials=creds, 1807 lp=lp) 1808 return ldb 1809 1810 1811def ou_name(ldb, instance_id): 1812 """Generate an ou name from the instance id""" 1813 return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id, 1814 ldb.domain_dn()) 1815 1816 1817def create_ou(ldb, instance_id): 1818 """Create an ou, all created user and machine accounts will belong to it. 1819 1820 This allows all the created resources to be cleaned up easily. 1821 """ 1822 ou = ou_name(ldb, instance_id) 1823 try: 1824 ldb.add({"dn": ou.split(',', 1)[1], 1825 "objectclass": "organizationalunit"}) 1826 except LdbError as e: 1827 (status, _) = e.args 1828 # ignore already exists 1829 if status != 68: 1830 raise 1831 try: 1832 ldb.add({"dn": ou, 1833 "objectclass": "organizationalunit"}) 1834 except LdbError as e: 1835 (status, _) = e.args 1836 # ignore already exists 1837 if status != 68: 1838 raise 1839 return ou 1840 1841 1842# ConversationAccounts holds details of the machine and user accounts 1843# associated with a conversation. 1844# 1845# We use a named tuple to reduce shared memory usage. 1846ConversationAccounts = namedtuple('ConversationAccounts', 1847 ('netbios_name', 1848 'machinepass', 1849 'username', 1850 'userpass')) 1851 1852 1853def generate_replay_accounts(ldb, instance_id, number, password): 1854 """Generate a series of unique machine and user account names.""" 1855 1856 accounts = [] 1857 for i in range(1, number + 1): 1858 netbios_name = machine_name(instance_id, i) 1859 username = user_name(instance_id, i) 1860 1861 account = ConversationAccounts(netbios_name, password, username, 1862 password) 1863 accounts.append(account) 1864 return accounts 1865 1866 1867def create_machine_account(ldb, instance_id, netbios_name, machinepass, 1868 traffic_account=True): 1869 """Create a machine account via ldap.""" 1870 1871 ou = ou_name(ldb, instance_id) 1872 dn = "cn=%s,%s" % (netbios_name, ou) 1873 utf16pw = ('"%s"' % get_string(machinepass)).encode('utf-16-le') 1874 1875 if traffic_account: 1876 # we set these bits for the machine account otherwise the replayed 1877 # traffic throws up NT_STATUS_NO_TRUST_SAM_ACCOUNT errors 1878 account_controls = str(UF_TRUSTED_FOR_DELEGATION | 1879 UF_SERVER_TRUST_ACCOUNT) 1880 1881 else: 1882 account_controls = str(UF_WORKSTATION_TRUST_ACCOUNT) 1883 1884 ldb.add({ 1885 "dn": dn, 1886 "objectclass": "computer", 1887 "sAMAccountName": "%s$" % netbios_name, 1888 "userAccountControl": account_controls, 1889 "unicodePwd": utf16pw}) 1890 1891 1892def create_user_account(ldb, instance_id, username, userpass): 1893 """Create a user account via ldap.""" 1894 ou = ou_name(ldb, instance_id) 1895 user_dn = "cn=%s,%s" % (username, ou) 1896 utf16pw = ('"%s"' % get_string(userpass)).encode('utf-16-le') 1897 ldb.add({ 1898 "dn": user_dn, 1899 "objectclass": "user", 1900 "sAMAccountName": username, 1901 "userAccountControl": str(UF_NORMAL_ACCOUNT), 1902 "unicodePwd": utf16pw 1903 }) 1904 1905 # grant user write permission to do things like write account SPN 1906 sdutils = sd_utils.SDUtils(ldb) 1907 sdutils.dacl_add_ace(user_dn, "(A;;WP;;;PS)") 1908 1909 1910def create_group(ldb, instance_id, name): 1911 """Create a group via ldap.""" 1912 1913 ou = ou_name(ldb, instance_id) 1914 dn = "cn=%s,%s" % (name, ou) 1915 ldb.add({ 1916 "dn": dn, 1917 "objectclass": "group", 1918 "sAMAccountName": name, 1919 }) 1920 1921 1922def user_name(instance_id, i): 1923 """Generate a user name based in the instance id""" 1924 return "STGU-%d-%d" % (instance_id, i) 1925 1926 1927def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'): 1928 """Seach objectclass, return attr in a set""" 1929 objs = ldb.search( 1930 expression="(objectClass={})".format(objectclass), 1931 attrs=[attr] 1932 ) 1933 return {str(obj[attr]) for obj in objs} 1934 1935 1936def generate_users(ldb, instance_id, number, password): 1937 """Add users to the server""" 1938 existing_objects = search_objectclass(ldb, objectclass='user') 1939 users = 0 1940 for i in range(number, 0, -1): 1941 name = user_name(instance_id, i) 1942 if name not in existing_objects: 1943 create_user_account(ldb, instance_id, name, password) 1944 users += 1 1945 if users % 50 == 0: 1946 LOGGER.info("Created %u/%u users" % (users, number)) 1947 1948 return users 1949 1950 1951def machine_name(instance_id, i, traffic_account=True): 1952 """Generate a machine account name from instance id.""" 1953 if traffic_account: 1954 # traffic accounts correspond to a given user, and use different 1955 # userAccountControl flags to ensure packets get processed correctly 1956 # by the DC 1957 return "STGM-%d-%d" % (instance_id, i) 1958 else: 1959 # Otherwise we're just generating computer accounts to simulate a 1960 # semi-realistic network. These use the default computer 1961 # userAccountControl flags, so we use a different account name so that 1962 # we don't try to use them when generating packets 1963 return "PC-%d-%d" % (instance_id, i) 1964 1965 1966def generate_machine_accounts(ldb, instance_id, number, password, 1967 traffic_account=True): 1968 """Add machine accounts to the server""" 1969 existing_objects = search_objectclass(ldb, objectclass='computer') 1970 added = 0 1971 for i in range(number, 0, -1): 1972 name = machine_name(instance_id, i, traffic_account) 1973 if name + "$" not in existing_objects: 1974 create_machine_account(ldb, instance_id, name, password, 1975 traffic_account) 1976 added += 1 1977 if added % 50 == 0: 1978 LOGGER.info("Created %u/%u machine accounts" % (added, number)) 1979 1980 return added 1981 1982 1983def group_name(instance_id, i): 1984 """Generate a group name from instance id.""" 1985 return "STGG-%d-%d" % (instance_id, i) 1986 1987 1988def generate_groups(ldb, instance_id, number): 1989 """Create the required number of groups on the server.""" 1990 existing_objects = search_objectclass(ldb, objectclass='group') 1991 groups = 0 1992 for i in range(number, 0, -1): 1993 name = group_name(instance_id, i) 1994 if name not in existing_objects: 1995 create_group(ldb, instance_id, name) 1996 groups += 1 1997 if groups % 1000 == 0: 1998 LOGGER.info("Created %u/%u groups" % (groups, number)) 1999 2000 return groups 2001 2002 2003def clean_up_accounts(ldb, instance_id): 2004 """Remove the created accounts and groups from the server.""" 2005 ou = ou_name(ldb, instance_id) 2006 try: 2007 ldb.delete(ou, ["tree_delete:1"]) 2008 except LdbError as e: 2009 (status, _) = e.args 2010 # ignore does not exist 2011 if status != 32: 2012 raise 2013 2014 2015def generate_users_and_groups(ldb, instance_id, password, 2016 number_of_users, number_of_groups, 2017 group_memberships, max_members, 2018 machine_accounts, traffic_accounts=True): 2019 """Generate the required users and groups, allocating the users to 2020 those groups.""" 2021 memberships_added = 0 2022 groups_added = 0 2023 computers_added = 0 2024 2025 create_ou(ldb, instance_id) 2026 2027 LOGGER.info("Generating dummy user accounts") 2028 users_added = generate_users(ldb, instance_id, number_of_users, password) 2029 2030 LOGGER.info("Generating dummy machine accounts") 2031 computers_added = generate_machine_accounts(ldb, instance_id, 2032 machine_accounts, password, 2033 traffic_accounts) 2034 2035 if number_of_groups > 0: 2036 LOGGER.info("Generating dummy groups") 2037 groups_added = generate_groups(ldb, instance_id, number_of_groups) 2038 2039 if group_memberships > 0: 2040 LOGGER.info("Assigning users to groups") 2041 assignments = GroupAssignments(number_of_groups, 2042 groups_added, 2043 number_of_users, 2044 users_added, 2045 group_memberships, 2046 max_members) 2047 LOGGER.info("Adding users to groups") 2048 add_users_to_groups(ldb, instance_id, assignments) 2049 memberships_added = assignments.total() 2050 2051 if (groups_added > 0 and users_added == 0 and 2052 number_of_groups != groups_added): 2053 LOGGER.warning("The added groups will contain no members") 2054 2055 LOGGER.info("Added %d users (%d machines), %d groups and %d memberships" % 2056 (users_added, computers_added, groups_added, 2057 memberships_added)) 2058 2059 2060class GroupAssignments(object): 2061 def __init__(self, number_of_groups, groups_added, number_of_users, 2062 users_added, group_memberships, max_members): 2063 2064 self.count = 0 2065 self.generate_group_distribution(number_of_groups) 2066 self.generate_user_distribution(number_of_users, group_memberships) 2067 self.max_members = max_members 2068 self.assignments = defaultdict(list) 2069 self.assign_groups(number_of_groups, groups_added, number_of_users, 2070 users_added, group_memberships) 2071 2072 def cumulative_distribution(self, weights): 2073 # make sure the probabilities conform to a cumulative distribution 2074 # spread between 0.0 and 1.0. Dividing by the weighted total gives each 2075 # probability a proportional share of 1.0. Higher probabilities get a 2076 # bigger share, so are more likely to be picked. We use the cumulative 2077 # value, so we can use random.random() as a simple index into the list 2078 dist = [] 2079 total = sum(weights) 2080 if total == 0: 2081 return None 2082 2083 cumulative = 0.0 2084 for probability in weights: 2085 cumulative += probability 2086 dist.append(cumulative / total) 2087 return dist 2088 2089 def generate_user_distribution(self, num_users, num_memberships): 2090 """Probability distribution of a user belonging to a group. 2091 """ 2092 # Assign a weighted probability to each user. Use the Pareto 2093 # Distribution so that some users are in a lot of groups, and the 2094 # bulk of users are in only a few groups. If we're assigning a large 2095 # number of group memberships, use a higher shape. This means slightly 2096 # fewer outlying users that are in large numbers of groups. The aim is 2097 # to have no users belonging to more than ~500 groups. 2098 if num_memberships > 5000000: 2099 shape = 3.0 2100 elif num_memberships > 2000000: 2101 shape = 2.5 2102 elif num_memberships > 300000: 2103 shape = 2.25 2104 else: 2105 shape = 1.75 2106 2107 weights = [] 2108 for x in range(1, num_users + 1): 2109 p = random.paretovariate(shape) 2110 weights.append(p) 2111 2112 # convert the weights to a cumulative distribution between 0.0 and 1.0 2113 self.user_dist = self.cumulative_distribution(weights) 2114 2115 def generate_group_distribution(self, n): 2116 """Probability distribution of a group containing a user.""" 2117 2118 # Assign a weighted probability to each user. Probability decreases 2119 # as the group-ID increases 2120 weights = [] 2121 for x in range(1, n + 1): 2122 p = 1 / (x**1.3) 2123 weights.append(p) 2124 2125 # convert the weights to a cumulative distribution between 0.0 and 1.0 2126 self.group_weights = weights 2127 self.group_dist = self.cumulative_distribution(weights) 2128 2129 def generate_random_membership(self): 2130 """Returns a randomly generated user-group membership""" 2131 2132 # the list items are cumulative distribution values between 0.0 and 2133 # 1.0, which makes random() a handy way to index the list to get a 2134 # weighted random user/group. (Here the user/group returned are 2135 # zero-based array indexes) 2136 user = bisect.bisect(self.user_dist, random.random()) 2137 group = bisect.bisect(self.group_dist, random.random()) 2138 2139 return user, group 2140 2141 def users_in_group(self, group): 2142 return self.assignments[group] 2143 2144 def get_groups(self): 2145 return self.assignments.keys() 2146 2147 def cap_group_membership(self, group, max_members): 2148 """Prevent the group's membership from exceeding the max specified""" 2149 num_members = len(self.assignments[group]) 2150 if num_members >= max_members: 2151 LOGGER.info("Group {0} has {1} members".format(group, num_members)) 2152 2153 # remove this group and then recalculate the cumulative 2154 # distribution, so this group is no longer selected 2155 self.group_weights[group - 1] = 0 2156 new_dist = self.cumulative_distribution(self.group_weights) 2157 self.group_dist = new_dist 2158 2159 def add_assignment(self, user, group): 2160 # the assignments are stored in a dictionary where key=group, 2161 # value=list-of-users-in-group (indexing by group-ID allows us to 2162 # optimize for DB membership writes) 2163 if user not in self.assignments[group]: 2164 self.assignments[group].append(user) 2165 self.count += 1 2166 2167 # check if there'a cap on how big the groups can grow 2168 if self.max_members: 2169 self.cap_group_membership(group, self.max_members) 2170 2171 def assign_groups(self, number_of_groups, groups_added, 2172 number_of_users, users_added, group_memberships): 2173 """Allocate users to groups. 2174 2175 The intention is to have a few users that belong to most groups, while 2176 the majority of users belong to a few groups. 2177 2178 A few groups will contain most users, with the remaining only having a 2179 few users. 2180 """ 2181 2182 if group_memberships <= 0: 2183 return 2184 2185 # Calculate the number of group menberships required 2186 group_memberships = math.ceil( 2187 float(group_memberships) * 2188 (float(users_added) / float(number_of_users))) 2189 2190 if self.max_members: 2191 group_memberships = min(group_memberships, 2192 self.max_members * number_of_groups) 2193 2194 existing_users = number_of_users - users_added - 1 2195 existing_groups = number_of_groups - groups_added - 1 2196 while self.total() < group_memberships: 2197 user, group = self.generate_random_membership() 2198 2199 if group > existing_groups or user > existing_users: 2200 # the + 1 converts the array index to the corresponding 2201 # group or user number 2202 self.add_assignment(user + 1, group + 1) 2203 2204 def total(self): 2205 return self.count 2206 2207 2208def add_users_to_groups(db, instance_id, assignments): 2209 """Takes the assignments of users to groups and applies them to the DB.""" 2210 2211 total = assignments.total() 2212 count = 0 2213 added = 0 2214 2215 for group in assignments.get_groups(): 2216 users_in_group = assignments.users_in_group(group) 2217 if len(users_in_group) == 0: 2218 continue 2219 2220 # Split up the users into chunks, so we write no more than 1K at a 2221 # time. (Minimizing the DB modifies is more efficient, but writing 2222 # 10K+ users to a single group becomes inefficient memory-wise) 2223 for chunk in range(0, len(users_in_group), 1000): 2224 chunk_of_users = users_in_group[chunk:chunk + 1000] 2225 add_group_members(db, instance_id, group, chunk_of_users) 2226 2227 added += len(chunk_of_users) 2228 count += 1 2229 if count % 50 == 0: 2230 LOGGER.info("Added %u/%u memberships" % (added, total)) 2231 2232def add_group_members(db, instance_id, group, users_in_group): 2233 """Adds the given users to group specified.""" 2234 2235 ou = ou_name(db, instance_id) 2236 2237 def build_dn(name): 2238 return("cn=%s,%s" % (name, ou)) 2239 2240 group_dn = build_dn(group_name(instance_id, group)) 2241 m = ldb.Message() 2242 m.dn = ldb.Dn(db, group_dn) 2243 2244 for user in users_in_group: 2245 user_dn = build_dn(user_name(instance_id, user)) 2246 idx = "member-" + str(user) 2247 m[idx] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member") 2248 2249 db.modify(m) 2250 2251 2252def generate_stats(statsdir, timing_file): 2253 """Generate and print the summary stats for a run.""" 2254 first = sys.float_info.max 2255 last = 0 2256 successful = 0 2257 failed = 0 2258 latencies = {} 2259 failures = Counter() 2260 unique_conversations = set() 2261 if timing_file is not None: 2262 tw = timing_file.write 2263 else: 2264 def tw(x): 2265 pass 2266 2267 tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n") 2268 2269 float_values = { 2270 'Maximum lag': 0, 2271 'Start lag': 0, 2272 'Max sleep miss': 0, 2273 } 2274 int_values = { 2275 'Planned_conversations': 0, 2276 'Planned_packets': 0, 2277 'Unfinished_conversations': 0, 2278 } 2279 2280 for filename in os.listdir(statsdir): 2281 path = os.path.join(statsdir, filename) 2282 with open(path, 'r') as f: 2283 for line in f: 2284 try: 2285 fields = line.rstrip('\n').split('\t') 2286 conversation = fields[1] 2287 protocol = fields[2] 2288 packet_type = fields[3] 2289 latency = float(fields[4]) 2290 t = float(fields[0]) 2291 first = min(t - latency, first) 2292 last = max(t, last) 2293 2294 op = (protocol, packet_type) 2295 latencies.setdefault(op, []).append(latency) 2296 if fields[5] == 'True': 2297 successful += 1 2298 else: 2299 failed += 1 2300 failures[op] += 1 2301 2302 unique_conversations.add(conversation) 2303 2304 tw(line) 2305 except (ValueError, IndexError): 2306 if ':' in line: 2307 k, v = line.split(':', 1) 2308 if k in float_values: 2309 float_values[k] = max(float(v), 2310 float_values[k]) 2311 elif k in int_values: 2312 int_values[k] = max(int(v), 2313 int_values[k]) 2314 else: 2315 print(line, file=sys.stderr) 2316 else: 2317 # not a valid line print and ignore 2318 print(line, file=sys.stderr) 2319 2320 duration = last - first 2321 if successful == 0: 2322 success_rate = 0 2323 else: 2324 success_rate = successful / duration 2325 if failed == 0: 2326 failure_rate = 0 2327 else: 2328 failure_rate = failed / duration 2329 2330 conversations = len(unique_conversations) 2331 2332 print("Total conversations: %10d" % conversations) 2333 print("Successful operations: %10d (%.3f per second)" 2334 % (successful, success_rate)) 2335 print("Failed operations: %10d (%.3f per second)" 2336 % (failed, failure_rate)) 2337 2338 for k, v in sorted(float_values.items()): 2339 print("%-28s %f" % (k.replace('_', ' ') + ':', v)) 2340 for k, v in sorted(int_values.items()): 2341 print("%-28s %d" % (k.replace('_', ' ') + ':', v)) 2342 2343 print("Protocol Op Code Description " 2344 " Count Failed Mean Median " 2345 "95% Range Max") 2346 2347 ops = {} 2348 for proto, packet in latencies: 2349 if proto not in ops: 2350 ops[proto] = set() 2351 ops[proto].add(packet) 2352 protocols = sorted(ops.keys()) 2353 2354 for protocol in protocols: 2355 packet_types = sorted(ops[protocol], key=opcode_key) 2356 for packet_type in packet_types: 2357 op = (protocol, packet_type) 2358 values = latencies[op] 2359 values = sorted(values) 2360 count = len(values) 2361 failed = failures[op] 2362 mean = sum(values) / count 2363 median = calc_percentile(values, 0.50) 2364 percentile = calc_percentile(values, 0.95) 2365 rng = values[-1] - values[0] 2366 maxv = values[-1] 2367 desc = OP_DESCRIPTIONS.get(op, '') 2368 print("%-12s %4s %-35s %12d %12d %12.6f " 2369 "%12.6f %12.6f %12.6f %12.6f" 2370 % (protocol, 2371 packet_type, 2372 desc, 2373 count, 2374 failed, 2375 mean, 2376 median, 2377 percentile, 2378 rng, 2379 maxv)) 2380 2381 2382def opcode_key(v): 2383 """Sort key for the operation code to ensure that it sorts numerically""" 2384 try: 2385 return "%03d" % int(v) 2386 except ValueError: 2387 return v 2388 2389 2390def calc_percentile(values, percentile): 2391 """Calculate the specified percentile from the list of values. 2392 2393 Assumes the list is sorted in ascending order. 2394 """ 2395 2396 if not values: 2397 return 0 2398 k = (len(values) - 1) * percentile 2399 f = math.floor(k) 2400 c = math.ceil(k) 2401 if f == c: 2402 return values[int(k)] 2403 d0 = values[int(f)] * (c - k) 2404 d1 = values[int(c)] * (k - f) 2405 return d0 + d1 2406 2407 2408def mk_masked_dir(*path): 2409 """In a testenv we end up with 0777 directories that look an alarming 2410 green colour with ls. Use umask to avoid that.""" 2411 # py3 os.mkdir can do this 2412 d = os.path.join(*path) 2413 mask = os.umask(0o077) 2414 os.mkdir(d) 2415 os.umask(mask) 2416 return d 2417