1#!~/.wine/drive_c/Python25/python.exe
2# -*- coding: utf-8 -*-
3
4# Copyright (c) 2009-2014, Mario Vilas
5# All rights reserved.
6#
7# Redistribution and use in source and binary forms, with or without
8# modification, are permitted provided that the following conditions are met:
9#
10#     * Redistributions of source code must retain the above copyright notice,
11#       this list of conditions and the following disclaimer.
12#     * Redistributions in binary form must reproduce the above copyright
13#       notice,this list of conditions and the following disclaimer in the
14#       documentation and/or other materials provided with the distribution.
15#     * Neither the name of the copyright holder nor the names of its
16#       contributors may be used to endorse or promote products derived from
17#       this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
23# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29# POSSIBILITY OF SUCH DAMAGE.
30
31"""
32Module instrumentation.
33
34@group Instrumentation:
35    Module
36
37@group Warnings:
38    DebugSymbolsWarning
39"""
40
41from __future__ import with_statement
42
43__revision__ = "$Id$"
44
45__all__ = ['Module', 'DebugSymbolsWarning']
46
47import sys
48from winappdbg import win32
49from winappdbg import compat
50from winappdbg.textio import HexInput, HexDump
51from winappdbg.util import PathOperations
52
53# delayed imports
54Process = None
55
56import os
57import warnings
58import traceback
59
60#==============================================================================
61
class DebugSymbolsWarning (UserWarning):
    """
    This warning is issued if the support for debug symbols
    isn't working properly.

    In this module it is issued through C{warnings.warn} by
    L{Module.load_symbols} when the symbol APIs fail with C{WindowsError}.
    """
67
68#==============================================================================
69
70class Module (object):
71    """
72    Interface to a DLL library loaded in the context of another process.
73
74    @group Properties:
75        get_base, get_filename, get_name, get_size, get_entry_point,
76        get_process, set_process, get_pid,
77        get_handle, set_handle, open_handle, close_handle
78
79    @group Labels:
80        get_label, get_label_at_address, is_address_here,
81        resolve, resolve_label, match_name
82
83    @group Symbols:
84        load_symbols, unload_symbols, get_symbols, iter_symbols,
85        resolve_symbol, get_symbol_at_address
86
87    @group Modules snapshot:
88        clear
89
90    @type unknown: str
91    @cvar unknown: Suggested tag for unknown modules.
92
93    @type lpBaseOfDll: int
94    @ivar lpBaseOfDll: Base of DLL module.
95        Use L{get_base} instead.
96
97    @type hFile: L{FileHandle}
98    @ivar hFile: Handle to the module file.
99        Use L{get_handle} instead.
100
101    @type fileName: str
102    @ivar fileName: Module filename.
103        Use L{get_filename} instead.
104
105    @type SizeOfImage: int
106    @ivar SizeOfImage: Size of the module.
107        Use L{get_size} instead.
108
109    @type EntryPoint: int
110    @ivar EntryPoint: Entry point of the module.
111        Use L{get_entry_point} instead.
112
113    @type process: L{Process}
114    @ivar process: Process where the module is loaded.
115        Use the L{get_process} method instead.
116    """
117
118    unknown = '<unknown>'
119
120    class _SymbolEnumerator (object):
121        """
122        Internally used by L{Module} to enumerate symbols in a module.
123        """
124
125        def __init__(self, undecorate = False):
126            self.symbols = list()
127            self.undecorate = undecorate
128
129        def __call__(self, SymbolName, SymbolAddress, SymbolSize, UserContext):
130            """
131            Callback that receives symbols and stores them in a Python list.
132            """
133            if self.undecorate:
134                try:
135                    SymbolName = win32.UnDecorateSymbolName(SymbolName)
136                except Exception:
137                    pass # not all symbols are decorated!
138            self.symbols.append( (SymbolName, SymbolAddress, SymbolSize) )
139            return win32.TRUE
140
141    def __init__(self, lpBaseOfDll, hFile = None, fileName    = None,
142                                                  SizeOfImage = None,
143                                                  EntryPoint  = None,
144                                                  process     = None):
145        """
146        @type  lpBaseOfDll: str
147        @param lpBaseOfDll: Base address of the module.
148
149        @type  hFile: L{FileHandle}
150        @param hFile: (Optional) Handle to the module file.
151
152        @type  fileName: str
153        @param fileName: (Optional) Module filename.
154
155        @type  SizeOfImage: int
156        @param SizeOfImage: (Optional) Size of the module.
157
158        @type  EntryPoint: int
159        @param EntryPoint: (Optional) Entry point of the module.
160
161        @type  process: L{Process}
162        @param process: (Optional) Process where the module is loaded.
163        """
164        self.lpBaseOfDll    = lpBaseOfDll
165        self.fileName       = fileName
166        self.SizeOfImage    = SizeOfImage
167        self.EntryPoint     = EntryPoint
168
169        self.__symbols = list()
170
171        self.set_handle(hFile)
172        self.set_process(process)
173
174    # Not really sure if it's a good idea...
175##    def __eq__(self, aModule):
176##        """
177##        Compare two Module objects. The comparison is made using the process
178##        IDs and the module bases.
179##
180##        @type  aModule: L{Module}
181##        @param aModule: Another Module object.
182##
183##        @rtype:  bool
184##        @return: C{True} if the two process IDs and module bases are equal,
185##            C{False} otherwise.
186##        """
187##        return isinstance(aModule, Module)           and \
188##               self.get_pid() == aModule.get_pid()   and \
189##               self.get_base() == aModule.get_base()
190
    def get_handle(self):
        """
        Getter used by the L{hFile} property: returns the stored file
        handle as-is, without trying to open one.

        @note: A second C{get_handle} method is defined further down in
            this class and replaces this one as the C{get_handle}
            attribute. The L{hFile} property was built from this
            definition and keeps using it.

        @rtype:  L{Handle}
        @return: File handle.
            Returns C{None} if unknown.
        """
        # no way to guess!
        return self.__hFile
199
200    def set_handle(self, hFile):
201        """
202        @type  hFile: L{Handle}
203        @param hFile: File handle. Use C{None} to clear.
204        """
205        if hFile == win32.INVALID_HANDLE_VALUE:
206            hFile = None
207        self.__hFile = hFile
208
    # Property over the stored file handle: reads return it as-is,
    # writes go through set_handle().
    hFile = property(get_handle, set_handle, doc="")
210
211    def get_process(self):
212        """
213        @rtype:  L{Process}
214        @return: Parent Process object.
215            Returns C{None} if unknown.
216        """
217        # no way to guess!
218        return self.__process
219
220    def set_process(self, process = None):
221        """
222        Manually set the parent process. Use with care!
223
224        @type  process: L{Process}
225        @param process: (Optional) Process object. Use C{None} for no process.
226        """
227        if process is None:
228            self.__process = None
229        else:
230            global Process      # delayed import
231            if Process is None:
232                from winappdbg.process import Process
233            if not isinstance(process, Process):
234                msg  = "Parent process must be a Process instance, "
235                msg += "got %s instead" % type(process)
236                raise TypeError(msg)
237            self.__process = process
238
    # Property over the parent process: reads use get_process(),
    # writes go through set_process() and its type check.
    process = property(get_process, set_process, doc="")
240
241    def get_pid(self):
242        """
243        @rtype:  int or None
244        @return: Parent process global ID.
245            Returns C{None} on error.
246        """
247        process = self.get_process()
248        if process is not None:
249            return process.get_pid()
250
251    def get_base(self):
252        """
253        @rtype:  int or None
254        @return: Base address of the module.
255            Returns C{None} if unknown.
256        """
257        return self.lpBaseOfDll
258
259    def get_size(self):
260        """
261        @rtype:  int or None
262        @return: Base size of the module.
263            Returns C{None} if unknown.
264        """
265        if not self.SizeOfImage:
266            self.__get_size_and_entry_point()
267        return self.SizeOfImage
268
269    def get_entry_point(self):
270        """
271        @rtype:  int or None
272        @return: Entry point of the module.
273            Returns C{None} if unknown.
274        """
275        if not self.EntryPoint:
276            self.__get_size_and_entry_point()
277        return self.EntryPoint
278
279    def __get_size_and_entry_point(self):
280        "Get the size and entry point of the module using the Win32 API."
281        process = self.get_process()
282        if process:
283            try:
284                handle = process.get_handle( win32.PROCESS_VM_READ |
285                                             win32.PROCESS_QUERY_INFORMATION )
286                base   = self.get_base()
287                mi     = win32.GetModuleInformation(handle, base)
288                self.SizeOfImage = mi.SizeOfImage
289                self.EntryPoint  = mi.EntryPoint
290            except WindowsError:
291                e = sys.exc_info()[1]
292                warnings.warn(
293                    "Cannot get size and entry point of module %s, reason: %s"\
294                    % (self.get_name(), e.strerror), RuntimeWarning)
295
296    def get_filename(self):
297        """
298        @rtype:  str or None
299        @return: Module filename.
300            Returns C{None} if unknown.
301        """
302        if self.fileName is None:
303            if self.hFile not in (None, win32.INVALID_HANDLE_VALUE):
304                fileName = self.hFile.get_filename()
305                if fileName:
306                    fileName = PathOperations.native_to_win32_pathname(fileName)
307                    self.fileName = fileName
308        return self.fileName
309
310    def __filename_to_modname(self, pathname):
311        """
312        @type  pathname: str
313        @param pathname: Pathname to a module.
314
315        @rtype:  str
316        @return: Module name.
317        """
318        filename = PathOperations.pathname_to_filename(pathname)
319        if filename:
320            filename = filename.lower()
321            filepart, extpart = PathOperations.split_extension(filename)
322            if filepart and extpart:
323                modName = filepart
324            else:
325                modName = filename
326        else:
327            modName = pathname
328        return modName
329
330    def get_name(self):
331        """
332        @rtype:  str
333        @return: Module name, as used in labels.
334
335        @warning: Names are B{NOT} guaranteed to be unique.
336
337            If you need unique identification for a loaded module,
338            use the base address instead.
339
340        @see: L{get_label}
341        """
342        pathname = self.get_filename()
343        if pathname:
344            modName = self.__filename_to_modname(pathname)
345            if isinstance(modName, compat.unicode):
346                try:
347                    modName = modName.encode('cp1252')
348                except UnicodeEncodeError:
349                    e = sys.exc_info()[1]
350                    warnings.warn(str(e))
351        else:
352            modName = "0x%x" % self.get_base()
353        return modName
354
355    def match_name(self, name):
356        """
357        @rtype:  bool
358        @return:
359            C{True} if the given name could refer to this module.
360            It may not be exactly the same returned by L{get_name}.
361        """
362
363        # If the given name is exactly our name, return True.
364        # Comparison is case insensitive.
365        my_name = self.get_name().lower()
366        if name.lower() == my_name:
367            return True
368
369        # If the given name is a base address, compare it with ours.
370        try:
371            base = HexInput.integer(name)
372        except ValueError:
373            base = None
374        if base is not None and base == self.get_base():
375            return True
376
377        # If the given name is a filename, convert it to a module name.
378        # Then compare it with ours, case insensitive.
379        modName = self.__filename_to_modname(name)
380        if modName.lower() == my_name:
381            return True
382
383        # No match.
384        return False
385
386#------------------------------------------------------------------------------
387
388    def open_handle(self):
389        """
390        Opens a new handle to the module.
391
392        The new handle is stored in the L{hFile} property.
393        """
394
395        if not self.get_filename():
396            msg = "Cannot retrieve filename for module at %s"
397            msg = msg % HexDump.address( self.get_base() )
398            raise Exception(msg)
399
400        hFile = win32.CreateFile(self.get_filename(),
401                                           dwShareMode = win32.FILE_SHARE_READ,
402                                 dwCreationDisposition = win32.OPEN_EXISTING)
403
404        # In case hFile was set to an actual handle value instead of a Handle
405        # object. This shouldn't happen unless the user tinkered with hFile.
406        if not hasattr(self.hFile, '__del__'):
407            self.close_handle()
408
409        self.hFile = hFile
410
411    def close_handle(self):
412        """
413        Closes the handle to the module.
414
415        @note: Normally you don't need to call this method. All handles
416            created by I{WinAppDbg} are automatically closed when the garbage
417            collector claims them. So unless you've been tinkering with it,
418            setting L{hFile} to C{None} should be enough.
419        """
420        try:
421            if hasattr(self.hFile, 'close'):
422                self.hFile.close()
423            elif self.hFile not in (None, win32.INVALID_HANDLE_VALUE):
424                win32.CloseHandle(self.hFile)
425        finally:
426            self.hFile = None
427
428    def get_handle(self):
429        """
430        @rtype:  L{FileHandle}
431        @return: Handle to the module file.
432        """
433        if self.hFile in (None, win32.INVALID_HANDLE_VALUE):
434            self.open_handle()
435        return self.hFile
436
437    def clear(self):
438        """
439        Clears the resources held by this object.
440        """
441        try:
442            self.set_process(None)
443        finally:
444            self.close_handle()
445
446#------------------------------------------------------------------------------
447
448    # XXX FIXME
449    # I've been told sometimes the debugging symbols APIs don't correctly
450    # handle redirected exports (for example ws2_32!recv).
451    # I haven't been able to reproduce the bug yet.
452    def load_symbols(self):
453        """
454        Loads the debugging symbols for a module.
455        Automatically called by L{get_symbols}.
456        """
457        if win32.PROCESS_ALL_ACCESS == win32.PROCESS_ALL_ACCESS_VISTA:
458            dwAccess = win32.PROCESS_QUERY_LIMITED_INFORMATION
459        else:
460            dwAccess = win32.PROCESS_QUERY_INFORMATION
461        hProcess     = self.get_process().get_handle(dwAccess)
462        hFile        = self.hFile
463        BaseOfDll    = self.get_base()
464        SizeOfDll    = self.get_size()
465        Enumerator   = self._SymbolEnumerator()
466        try:
467            win32.SymInitialize(hProcess)
468            SymOptions = win32.SymGetOptions()
469            SymOptions |= (
470                win32.SYMOPT_ALLOW_ZERO_ADDRESS     |
471                win32.SYMOPT_CASE_INSENSITIVE       |
472                win32.SYMOPT_FAVOR_COMPRESSED       |
473                win32.SYMOPT_INCLUDE_32BIT_MODULES  |
474                win32.SYMOPT_UNDNAME
475            )
476            SymOptions &= ~(
477                win32.SYMOPT_LOAD_LINES         |
478                win32.SYMOPT_NO_IMAGE_SEARCH    |
479                win32.SYMOPT_NO_CPP             |
480                win32.SYMOPT_IGNORE_NT_SYMPATH
481            )
482            win32.SymSetOptions(SymOptions)
483            try:
484                win32.SymSetOptions(
485                    SymOptions | win32.SYMOPT_ALLOW_ABSOLUTE_SYMBOLS)
486            except WindowsError:
487                pass
488            try:
489                try:
490                    success = win32.SymLoadModule64(
491                        hProcess, hFile, None, None, BaseOfDll, SizeOfDll)
492                except WindowsError:
493                    success = 0
494                if not success:
495                    ImageName = self.get_filename()
496                    success = win32.SymLoadModule64(
497                        hProcess, None, ImageName, None, BaseOfDll, SizeOfDll)
498                if success:
499                    try:
500                        win32.SymEnumerateSymbols64(
501                            hProcess, BaseOfDll, Enumerator)
502                    finally:
503                        win32.SymUnloadModule64(hProcess, BaseOfDll)
504            finally:
505                win32.SymCleanup(hProcess)
506        except WindowsError:
507            e = sys.exc_info()[1]
508            msg = "Cannot load debug symbols for process ID %d, reason:\n%s"
509            msg = msg % (self.get_pid(), traceback.format_exc(e))
510            warnings.warn(msg, DebugSymbolsWarning)
511        self.__symbols = Enumerator.symbols
512
513    def unload_symbols(self):
514        """
515        Unloads the debugging symbols for a module.
516        """
517        self.__symbols = list()
518
519    def get_symbols(self):
520        """
521        Returns the debugging symbols for a module.
522        The symbols are automatically loaded when needed.
523
524        @rtype:  list of tuple( str, int, int )
525        @return: List of symbols.
526            Each symbol is represented by a tuple that contains:
527                - Symbol name
528                - Symbol memory address
529                - Symbol size in bytes
530        """
531        if not self.__symbols:
532            self.load_symbols()
533        return list(self.__symbols)
534
535    def iter_symbols(self):
536        """
537        Returns an iterator for the debugging symbols in a module,
538        in no particular order.
539        The symbols are automatically loaded when needed.
540
541        @rtype:  iterator of tuple( str, int, int )
542        @return: Iterator of symbols.
543            Each symbol is represented by a tuple that contains:
544                - Symbol name
545                - Symbol memory address
546                - Symbol size in bytes
547        """
548        if not self.__symbols:
549            self.load_symbols()
550        return self.__symbols.__iter__()
551
552    def resolve_symbol(self, symbol, bCaseSensitive = False):
553        """
554        Resolves a debugging symbol's address.
555
556        @type  symbol: str
557        @param symbol: Name of the symbol to resolve.
558
559        @type  bCaseSensitive: bool
560        @param bCaseSensitive: C{True} for case sensitive matches,
561            C{False} for case insensitive.
562
563        @rtype:  int or None
564        @return: Memory address of symbol. C{None} if not found.
565        """
566        if bCaseSensitive:
567            for (SymbolName, SymbolAddress, SymbolSize) in self.iter_symbols():
568                if symbol == SymbolName:
569                    return SymbolAddress
570            for (SymbolName, SymbolAddress, SymbolSize) in self.iter_symbols():
571                try:
572                    SymbolName = win32.UnDecorateSymbolName(SymbolName)
573                except Exception:
574                    continue
575                if symbol == SymbolName:
576                    return SymbolAddress
577        else:
578            symbol = symbol.lower()
579            for (SymbolName, SymbolAddress, SymbolSize) in self.iter_symbols():
580                if symbol == SymbolName.lower():
581                    return SymbolAddress
582            for (SymbolName, SymbolAddress, SymbolSize) in self.iter_symbols():
583                try:
584                    SymbolName = win32.UnDecorateSymbolName(SymbolName)
585                except Exception:
586                    continue
587                if symbol == SymbolName.lower():
588                    return SymbolAddress
589
590    def get_symbol_at_address(self, address):
591        """
592        Tries to find the closest matching symbol for the given address.
593
594        @type  address: int
595        @param address: Memory address to query.
596
597        @rtype: None or tuple( str, int, int )
598        @return: Returns a tuple consisting of:
599             - Name
600             - Address
601             - Size (in bytes)
602            Returns C{None} if no symbol could be matched.
603        """
604        found = None
605        for (SymbolName, SymbolAddress, SymbolSize) in self.iter_symbols():
606            if SymbolAddress > address:
607                continue
608            if SymbolAddress + SymbolSize > address:
609                if not found or found[1] < SymbolAddress:
610                    found = (SymbolName, SymbolAddress, SymbolSize)
611        return found
612
613#------------------------------------------------------------------------------
614
615    def get_label(self, function = None, offset = None):
616        """
617        Retrieves the label for the given function of this module or the module
618        base address if no function name is given.
619
620        @type  function: str
621        @param function: (Optional) Exported function name.
622
623        @type  offset: int
624        @param offset: (Optional) Offset from the module base address.
625
626        @rtype:  str
627        @return: Label for the module base address, plus the offset if given.
628        """
629        return _ModuleContainer.parse_label(self.get_name(), function, offset)
630
631    def get_label_at_address(self, address, offset = None):
632        """
633        Creates a label from the given memory address.
634
635        If the address belongs to the module, the label is made relative to
636        it's base address.
637
638        @type  address: int
639        @param address: Memory address.
640
641        @type  offset: None or int
642        @param offset: (Optional) Offset value.
643
644        @rtype:  str
645        @return: Label pointing to the given address.
646        """
647
648        # Add the offset to the address.
649        if offset:
650            address = address + offset
651
652        # Make the label relative to the base address if no match is found.
653        module      = self.get_name()
654        function    = None
655        offset      = address - self.get_base()
656
657        # Make the label relative to the entrypoint if no other match is found.
658        # Skip if the entry point is unknown.
659        start = self.get_entry_point()
660        if start and start <= address:
661            function    = "start"
662            offset      = address - start
663
664        # Enumerate exported functions and debug symbols,
665        # then find the closest match, if possible.
666        try:
667            symbol = self.get_symbol_at_address(address)
668            if symbol:
669                (SymbolName, SymbolAddress, SymbolSize) = symbol
670                new_offset = address - SymbolAddress
671                if new_offset <= offset:
672                    function    = SymbolName
673                    offset      = new_offset
674        except WindowsError:
675            pass
676
677        # Parse the label and return it.
678        return _ModuleContainer.parse_label(module, function, offset)
679
680    def is_address_here(self, address):
681        """
682        Tries to determine if the given address belongs to this module.
683
684        @type  address: int
685        @param address: Memory address.
686
687        @rtype:  bool or None
688        @return: C{True} if the address belongs to the module,
689            C{False} if it doesn't,
690            and C{None} if it can't be determined.
691        """
692        base = self.get_base()
693        size = self.get_size()
694        if base and size:
695            return base <= address < (base + size)
696        return None
697
698    def resolve(self, function):
699        """
700        Resolves a function exported by this module.
701
702        @type  function: str or int
703        @param function:
704            str: Name of the function.
705            int: Ordinal of the function.
706
707        @rtype:  int
708        @return: Memory address of the exported function in the process.
709            Returns None on error.
710        """
711
712        # Unknown DLL filename, there's nothing we can do.
713        filename = self.get_filename()
714        if not filename:
715            return None
716
717        # If the DLL is already mapped locally, resolve the function.
718        try:
719            hlib    = win32.GetModuleHandle(filename)
720            address = win32.GetProcAddress(hlib, function)
721        except WindowsError:
722
723            # Load the DLL locally, resolve the function and unload it.
724            try:
725                hlib = win32.LoadLibraryEx(filename,
726                                           win32.DONT_RESOLVE_DLL_REFERENCES)
727                try:
728                    address = win32.GetProcAddress(hlib, function)
729                finally:
730                    win32.FreeLibrary(hlib)
731            except WindowsError:
732                return None
733
734        # A NULL pointer means the function was not found.
735        if address in (None, 0):
736            return None
737
738        # Compensate for DLL base relocations locally and remotely.
739        return address - hlib + self.lpBaseOfDll
740
741    def resolve_label(self, label):
742        """
743        Resolves a label for this module only. If the label refers to another
744        module, an exception is raised.
745
746        @type  label: str
747        @param label: Label to resolve.
748
749        @rtype:  int
750        @return: Memory address pointed to by the label.
751
752        @raise ValueError: The label is malformed or impossible to resolve.
753        @raise RuntimeError: Cannot resolve the module or function.
754        """
755
756        # Split the label into it's components.
757        # Use the fuzzy mode whenever possible.
758        aProcess = self.get_process()
759        if aProcess is not None:
760            (module, procedure, offset) = aProcess.split_label(label)
761        else:
762            (module, procedure, offset) = _ModuleContainer.split_label(label)
763
764        # If a module name is given that doesn't match ours,
765        # raise an exception.
766        if module and not self.match_name(module):
767            raise RuntimeError("Label does not belong to this module")
768
769        # Resolve the procedure if given.
770        if procedure:
771            address = self.resolve(procedure)
772            if address is None:
773
774                # If it's a debug symbol, use the symbol.
775                address = self.resolve_symbol(procedure)
776
777                # If it's the keyword "start" use the entry point.
778                if address is None and procedure == "start":
779                    address = self.get_entry_point()
780
781                # The procedure was not found.
782                if address is None:
783                    if not module:
784                        module = self.get_name()
785                    msg = "Can't find procedure %s in module %s"
786                    raise RuntimeError(msg % (procedure, module))
787
788        # If no procedure is given use the base address of the module.
789        else:
790            address = self.get_base()
791
792        # Add the offset if given and return the resolved address.
793        if offset:
794            address = address + offset
795        return address
796
797#==============================================================================
798
799# TODO
800# An alternative approach to the toolhelp32 snapshots: parsing the PEB and
801# fetching the list of loaded modules from there. That would solve the problem
802# of toolhelp32 not working when the process hasn't finished initializing.
803# See: http://pferrie.host22.com/misc/lowlevel3.htm
804
805class _ModuleContainer (object):
806    """
807    Encapsulates the capability to contain Module objects.
808
809    @note: Labels are an approximated way of referencing memory locations
810        across different executions of the same process, or different processes
811        with common modules. They are not meant to be perfectly unique, and
812        some errors may occur when multiple modules with the same name are
813        loaded, or when module filenames can't be retrieved.
814
815    @group Modules snapshot:
816        scan_modules,
817        get_module, get_module_bases, get_module_count,
818        get_module_at_address, get_module_by_name,
819        has_module, iter_modules, iter_module_addresses,
820        clear_modules
821
822    @group Labels:
823        parse_label, split_label, sanitize_label, resolve_label,
824        resolve_label_components, get_label_at_address, split_label_strict,
825        split_label_fuzzy
826
827    @group Symbols:
828        load_symbols, unload_symbols, get_symbols, iter_symbols,
829        resolve_symbol, get_symbol_at_address
830
831    @group Debugging:
832        is_system_defined_breakpoint, get_system_breakpoint,
833        get_user_breakpoint, get_breakin_breakpoint,
834        get_wow64_system_breakpoint, get_wow64_user_breakpoint,
835        get_wow64_breakin_breakpoint, get_break_on_error_ptr
836    """
837
838    def __init__(self):
839        self.__moduleDict = dict()
840        self.__system_breakpoints = dict()
841
842        # Replace split_label with the fuzzy version on object instances.
843        self.split_label = self.__use_fuzzy_mode
844
845    def __initialize_snapshot(self):
846        """
847        Private method to automatically initialize the snapshot
848        when you try to use it without calling any of the scan_*
849        methods first. You don't need to call this yourself.
850        """
851        if not self.__moduleDict:
852            try:
853                self.scan_modules()
854            except WindowsError:
855                pass
856
857    def __contains__(self, anObject):
858        """
859        @type  anObject: L{Module}, int
860        @param anObject:
861            - C{Module}: Module object to look for.
862            - C{int}: Base address of the DLL to look for.
863
864        @rtype:  bool
865        @return: C{True} if the snapshot contains
866            a L{Module} object with the same base address.
867        """
868        if isinstance(anObject, Module):
869            anObject = anObject.lpBaseOfDll
870        return self.has_module(anObject)
871
872    def __iter__(self):
873        """
874        @see:    L{iter_modules}
875        @rtype:  dictionary-valueiterator
876        @return: Iterator of L{Module} objects in this snapshot.
877        """
878        return self.iter_modules()
879
880    def __len__(self):
881        """
882        @see:    L{get_module_count}
883        @rtype:  int
884        @return: Count of L{Module} objects in this snapshot.
885        """
886        return self.get_module_count()
887
888    def has_module(self, lpBaseOfDll):
889        """
890        @type  lpBaseOfDll: int
891        @param lpBaseOfDll: Base address of the DLL to look for.
892
893        @rtype:  bool
894        @return: C{True} if the snapshot contains a
895            L{Module} object with the given base address.
896        """
897        self.__initialize_snapshot()
898        return lpBaseOfDll in self.__moduleDict
899
900    def get_module(self, lpBaseOfDll):
901        """
902        @type  lpBaseOfDll: int
903        @param lpBaseOfDll: Base address of the DLL to look for.
904
905        @rtype:  L{Module}
906        @return: Module object with the given base address.
907        """
908        self.__initialize_snapshot()
909        if lpBaseOfDll not in self.__moduleDict:
910            msg = "Unknown DLL base address %s"
911            msg = msg % HexDump.address(lpBaseOfDll)
912            raise KeyError(msg)
913        return self.__moduleDict[lpBaseOfDll]
914
915    def iter_module_addresses(self):
916        """
917        @see:    L{iter_modules}
918        @rtype:  dictionary-keyiterator
919        @return: Iterator of DLL base addresses in this snapshot.
920        """
921        self.__initialize_snapshot()
922        return compat.iterkeys(self.__moduleDict)
923
924    def iter_modules(self):
925        """
926        @see:    L{iter_module_addresses}
927        @rtype:  dictionary-valueiterator
928        @return: Iterator of L{Module} objects in this snapshot.
929        """
930        self.__initialize_snapshot()
931        return compat.itervalues(self.__moduleDict)
932
933    def get_module_bases(self):
934        """
935        @see:    L{iter_module_addresses}
936        @rtype:  list( int... )
937        @return: List of DLL base addresses in this snapshot.
938        """
939        self.__initialize_snapshot()
940        return compat.keys(self.__moduleDict)
941
942    def get_module_count(self):
943        """
944        @rtype:  int
945        @return: Count of L{Module} objects in this snapshot.
946        """
947        self.__initialize_snapshot()
948        return len(self.__moduleDict)
949
950#------------------------------------------------------------------------------
951
952    def get_module_by_name(self, modName):
953        """
954        @type  modName: int
955        @param modName:
956            Name of the module to look for, as returned by L{Module.get_name}.
957            If two or more modules with the same name are loaded, only one
958            of the matching modules is returned.
959
960            You can also pass a full pathname to the DLL file.
961            This works correctly even if two modules with the same name
962            are loaded from different paths.
963
964        @rtype:  L{Module}
965        @return: C{Module} object that best matches the given name.
966            Returns C{None} if no C{Module} can be found.
967        """
968
969        # Convert modName to lowercase.
970        # This helps make case insensitive string comparisons.
971        modName = modName.lower()
972
973        # modName is an absolute pathname.
974        if PathOperations.path_is_absolute(modName):
975            for lib in self.iter_modules():
976                if modName == lib.get_filename().lower():
977                    return lib
978            return None     # Stop trying to match the name.
979
980        # Get all the module names.
981        # This prevents having to iterate through the module list
982        #  more than once.
983        modDict = [ ( lib.get_name(), lib ) for lib in self.iter_modules() ]
984        modDict = dict(modDict)
985
986        # modName is a base filename.
987        if modName in modDict:
988            return modDict[modName]
989
990        # modName is a base filename without extension.
991        filepart, extpart = PathOperations.split_extension(modName)
992        if filepart and extpart:
993            if filepart in modDict:
994                return modDict[filepart]
995
996        # modName is a base address.
997        try:
998            baseAddress = HexInput.integer(modName)
999        except ValueError:
1000            return None
1001        if self.has_module(baseAddress):
1002            return self.get_module(baseAddress)
1003
1004        # Module not found.
1005        return None
1006
1007    def get_module_at_address(self, address):
1008        """
1009        @type  address: int
1010        @param address: Memory address to query.
1011
1012        @rtype:  L{Module}
1013        @return: C{Module} object that best matches the given address.
1014            Returns C{None} if no C{Module} can be found.
1015        """
1016        bases = self.get_module_bases()
1017        bases.sort()
1018        bases.append(long(0x10000000000000000))  # max. 64 bit address + 1
1019        if address >= bases[0]:
1020            i = 0
1021            max_i = len(bases) - 1
1022            while i < max_i:
1023                begin, end = bases[i:i+2]
1024                if begin <= address < end:
1025                    module = self.get_module(begin)
1026                    here   = module.is_address_here(address)
1027                    if here is False:
1028                        break
1029                    else:   # True or None
1030                        return module
1031                i = i + 1
1032        return None
1033
    # XXX this method mustn't end up calling __initialize_snapshot by accident!
    def scan_modules(self):
        """
        Populates the snapshot with loaded modules.

        The modules are enumerated using the toolhelp API. New modules are
        added to the snapshot, modules that were already there are kept
        (only their missing filename, image size and process are filled in)
        and modules that are no longer listed are removed.

        Nothing is done for the special process IDs 0, 4 and 8.

        @note: Errors from the toolhelp API calls are not handled here.
            Presumably they surface as C{WindowsError}, which is the
            exception the automatic snapshot initialization ignores.
        """

        # The module filenames may be spoofed by malware,
        # since this information resides in usermode space.
        # See: http://www.ragestorm.net/blogs/?p=163

        # Ignore special process IDs.
        # PID 0: System Idle Process. Also has a special meaning to the
        #        toolhelp APIs (current process).
        # PID 4: System Integrity Group. See this forum post for more info:
        #        http://tinyurl.com/ycza8jo
        #        (points to social.technet.microsoft.com)
        #        Only on XP and above
        # PID 8: System (?) only in Windows 2000 and below AFAIK.
        #        It's probably the same as PID 4 in XP and above.
        dwProcessId = self.get_pid()
        if dwProcessId in (0, 4, 8):
            return

        # It would seem easier to clear the snapshot first.
        # But then all open handles would be closed.
        # Instead, remember which base addresses were seen in this scan
        # and remove the stale modules at the end.
        found_bases = set()
        with win32.CreateToolhelp32Snapshot(win32.TH32CS_SNAPMODULE,
                                            dwProcessId) as hSnapshot:
            me = win32.Module32First(hSnapshot)
            while me is not None:
                lpBaseAddress = me.modBaseAddr
                fileName      = me.szExePath    # full pathname
                if not fileName:
                    fileName  = me.szModule     # filename only
                    if not fileName:
                        fileName = None
                else:
                    fileName = PathOperations.native_to_win32_pathname(fileName)
                found_bases.add(lpBaseAddress)

                # The private dictionary is queried directly on purpose:
                # calling self.has_module() here would trigger a scan.
                if lpBaseAddress not in self.__moduleDict:
                    aModule = Module(lpBaseAddress, fileName = fileName,
                                           SizeOfImage = me.modBaseSize,
                                           process = self)
                    self._add_module(aModule)
                else:
                    # Calling get_module() is safe at this point: the snapshot
                    # is known not to be empty, so no scan is triggered.
                    aModule = self.get_module(lpBaseAddress)

                    # Only fill in what's missing, never overwrite.
                    if not aModule.fileName:
                        aModule.fileName    = fileName
                    if not aModule.SizeOfImage:
                        aModule.SizeOfImage = me.modBaseSize
                    if not aModule.process:
                        aModule.process     = self
                me = win32.Module32Next(hSnapshot)

        # Remove the modules that weren't found in this scan.
        # Again the private dictionary is used directly, since calling
        # self.get_module_bases() would trigger a scan. A copy of the keys
        # is iterated because the dictionary's modules are being deleted.
        for base in compat.keys(self.__moduleDict):
            if base not in found_bases:
                self._del_module(base)
1092
1093    def clear_modules(self):
1094        """
1095        Clears the modules snapshot.
1096        """
1097        for aModule in compat.itervalues(self.__moduleDict):
1098            aModule.clear()
1099        self.__moduleDict = dict()
1100
1101#------------------------------------------------------------------------------
1102
1103    @staticmethod
1104    def parse_label(module = None, function = None, offset = None):
1105        """
1106        Creates a label from a module and a function name, plus an offset.
1107
1108        @warning: This method only creates the label, it doesn't make sure the
1109            label actually points to a valid memory location.
1110
1111        @type  module: None or str
1112        @param module: (Optional) Module name.
1113
1114        @type  function: None, str or int
1115        @param function: (Optional) Function name or ordinal.
1116
1117        @type  offset: None or int
1118        @param offset: (Optional) Offset value.
1119
1120            If C{function} is specified, offset from the function.
1121
1122            If C{function} is C{None}, offset from the module.
1123
1124        @rtype:  str
1125        @return:
1126            Label representing the given function in the given module.
1127
1128        @raise ValueError:
1129            The module or function name contain invalid characters.
1130        """
1131
1132        # TODO
1133        # Invalid characters should be escaped or filtered.
1134
1135        # Convert ordinals to strings.
1136        try:
1137            function = "#0x%x" % function
1138        except TypeError:
1139            pass
1140
1141        # Validate the parameters.
1142        if module is not None and ('!' in module or '+' in module):
1143            raise ValueError("Invalid module name: %s" % module)
1144        if function is not None and ('!' in function or '+' in function):
1145            raise ValueError("Invalid function name: %s" % function)
1146
1147        # Parse the label.
1148        if module:
1149            if function:
1150                if offset:
1151                    label = "%s!%s+0x%x" % (module, function, offset)
1152                else:
1153                    label = "%s!%s" % (module, function)
1154            else:
1155                if offset:
1156##                    label = "%s+0x%x!" % (module, offset)
1157                    label = "%s!0x%x" % (module, offset)
1158                else:
1159                    label = "%s!" % module
1160        else:
1161            if function:
1162                if offset:
1163                    label = "!%s+0x%x" % (function, offset)
1164                else:
1165                    label = "!%s" % function
1166            else:
1167                if offset:
1168                    label = "0x%x" % offset
1169                else:
1170                    label = "0x0"
1171
1172        return label
1173
    @staticmethod
    def split_label_strict(label):
        """
        Splits a label created with L{parse_label}.

        To parse labels with a less strict syntax, use the L{split_label_fuzzy}
        method instead.

        Blanks are removed from the label before parsing it. An empty label
        or C{None} is parsed as C{"0x0"}. Labels without an exclamation mark
        must be either an offset or an ordinal (C{#} followed by a number),
        anything else is considered ambiguous.

        @warning: This method only parses the label, it doesn't make sure the
            label actually points to a valid memory location.

        @type  label: str
        @param label: Label to split.

        @rtype:  tuple( str or None, str or int or None, int or None )
        @return: Tuple containing the C{module} name,
            the C{function} name or ordinal, and the C{offset} value.

            If the label doesn't specify a module,
            then C{module} is C{None}.

            If the label doesn't specify a function,
            then C{function} is C{None}.

            If the label doesn't specify an offset, or the offset is zero,
            then C{offset} is C{None}.

        @raise ValueError: The label is malformed or ambiguous.
        """
        module = function = None
        offset = 0

        # Special case: None
        if not label:
            label = "0x0"
        else:

            # Remove all blanks.
            label = label.replace(' ', '')
            label = label.replace('\t', '')
            label = label.replace('\r', '')
            label = label.replace('\n', '')

            # Special case: empty label.
            if not label:
                label = "0x0"

        # * ! *
        if '!' in label:

            # Exactly one exclamation mark is allowed.
            try:
                module, function = label.split('!')
            except ValueError:
                raise ValueError("Malformed label: %s" % label)

            # module ! function
            if function:
                if '+' in module:
                    raise ValueError("Malformed label: %s" % label)

                # module ! function + offset
                if '+' in function:
                    try:
                        function, offset = function.split('+')
                    except ValueError:
                        raise ValueError("Malformed label: %s" % label)
                    try:
                        offset = HexInput.integer(offset)
                    except ValueError:
                        raise ValueError("Malformed label: %s" % label)
                else:

                    # module ! offset
                    # (if it's not a number it's a function name, keep it)
                    try:
                        offset   = HexInput.integer(function)
                        function = None
                    except ValueError:
                        pass
            else:

                # module + offset !
                if '+' in module:
                    try:
                        module, offset = module.split('+')
                    except ValueError:
                        raise ValueError("Malformed label: %s" % label)
                    try:
                        offset = HexInput.integer(offset)
                    except ValueError:
                        raise ValueError("Malformed label: %s" % label)

                else:

                    # offset !
                    try:
                        offset = HexInput.integer(module)
                        module = None

                    # module !
                    except ValueError:
                        pass

            # Convert empty strings to None.
            if not module:
                module   = None
            if not function:
                function = None

        # *
        else:

            # offset
            try:
                offset = HexInput.integer(label)

            # # ordinal
            except ValueError:
                if label.startswith('#'):
                    function = label

                    # Only validate the ordinal here,
                    # it's converted to an integer below.
                    try:
                        HexInput.integer(function[1:])

                    # module?
                    # function?
                    except ValueError:
                        raise ValueError("Ambiguous label: %s" % label)

                # module?
                # function?
                else:
                    raise ValueError("Ambiguous label: %s" % label)

        # Convert function ordinal strings into integers.
        # Names that begin with '#' but aren't numbers are left as strings.
        if function and function.startswith('#'):
            try:
                function = HexInput.integer(function[1:])
            except ValueError:
                pass

        # Convert null offsets to None.
        if not offset:
            offset = None

        return (module, function, offset)
1316
    def split_label_fuzzy(self, label):
        """
        Splits a label entered as user input.

        It's more flexible in its syntax parsing than the L{split_label_strict}
        method, as it allows the exclamation mark (B{C{!}}) to be omitted. The
        ambiguity is resolved by searching the modules in the snapshot to guess
        if a label refers to a module or a function. It also tries to rebuild
        labels when they contain hardcoded addresses.

        Labels that do contain an exclamation mark are parsed by
        L{split_label_strict}.

        @warning: This method only parses the label, it doesn't make sure the
            label actually points to a valid memory location.

        @type  label: str
        @param label: Label to split.

        @rtype:  tuple( str or None, str or int or None, int or None )
        @return: Tuple containing the C{module} name,
            the C{function} name or ordinal, and the C{offset} value.

            If the label doesn't specify a module,
            then C{module} is C{None}.

            If the label doesn't specify a function,
            then C{function} is C{None}.

            If the label doesn't specify an offset, or the offset is zero,
            then C{offset} is C{None}.

        @raise ValueError: The label is malformed.
        """
        module = function = None
        offset = 0

        # Special case: None
        if not label:
            label = compat.b("0x0")
        else:

            # Remove all blanks.
            label = label.replace(compat.b(' '), compat.b(''))
            label = label.replace(compat.b('\t'), compat.b(''))
            label = label.replace(compat.b('\r'), compat.b(''))
            label = label.replace(compat.b('\n'), compat.b(''))

            # Special case: empty label.
            if not label:
                label = compat.b("0x0")

        # If an exclamation sign is present, we know we can parse it strictly.
        if compat.b('!') in label:
            return self.split_label_strict(label)

##        # Try to parse it strictly, on error do it the fuzzy way.
##        try:
##            return self.split_label(label)
##        except ValueError:
##            pass

        # * + offset
        # (only one plus sign is allowed, and the offset must be a number)
        if compat.b('+') in label:
            try:
                prefix, offset = label.split(compat.b('+'))
            except ValueError:
                raise ValueError("Malformed label: %s" % label)
            try:
                offset = HexInput.integer(offset)
            except ValueError:
                raise ValueError("Malformed label: %s" % label)
            label = prefix

        # This parses both filenames and base addresses.
        modobj = self.get_module_by_name(label)
        if modobj:

            # module
            # module + offset
            module = modobj.get_name()

        else:

            # TODO
            # If 0xAAAAAAAA + 0xBBBBBBBB is given,
            # A is interpreted as a module base address,
            # and B as an offset.
            # If that fails, it'd be good to add A+B and try to
            # use the nearest loaded module.

            # offset
            # base address + offset (when no module has that base address)
            try:
                address = HexInput.integer(label)

                if offset:
                    # If 0xAAAAAAAA + 0xBBBBBBBB is given,
                    # A is interpreted as a module base address,
                    # and B as an offset.
                    # If that fails, we get here, meaning no module was found
                    # at A. Then add up A+B and work with that as a hardcoded
                    # address.
                    offset = address + offset
                else:
                    # If the label is a hardcoded address, we get here.
                    offset = address

                # If only a hardcoded address is given,
                # rebuild the label using get_label_at_address.
                # Then parse it again, but this time strictly,
                # both because there is no need for fuzzy syntax and
                # to prevent an infinite recursion if there's a bug here.
                # If the new label can't be parsed, the hardcoded address
                # is returned as the offset, with no module or function.
                try:
                    new_label = self.get_label_at_address(offset)
                    module, function, offset = \
                                             self.split_label_strict(new_label)
                except ValueError:
                    pass

            # function
            # function + offset
            except ValueError:
                function = label

        # Convert function ordinal strings into integers.
        # Names that begin with '#' but aren't numbers are left as strings.
        if function and function.startswith(compat.b('#')):
            try:
                function = HexInput.integer(function[1:])
            except ValueError:
                pass

        # Convert null offsets to None.
        if not offset:
            offset = None

        return (module, function, offset)
1451
1452    @classmethod
1453    def split_label(cls, label):
1454        """
1455Splits a label into it's C{module}, C{function} and C{offset}
1456components, as used in L{parse_label}.
1457
1458When called as a static method, the strict syntax mode is used::
1459
1460    winappdbg.Process.split_label( "kernel32!CreateFileA" )
1461
1462When called as an instance method, the fuzzy syntax mode is used::
1463
1464    aProcessInstance.split_label( "CreateFileA" )
1465
1466@see: L{split_label_strict}, L{split_label_fuzzy}
1467
1468@type  label: str
1469@param label: Label to split.
1470
1471@rtype:  tuple( str or None, str or int or None, int or None )
1472@return:
1473    Tuple containing the C{module} name,
1474    the C{function} name or ordinal, and the C{offset} value.
1475
1476    If the label doesn't specify a module,
1477    then C{module} is C{None}.
1478
1479    If the label doesn't specify a function,
1480    then C{function} is C{None}.
1481
1482    If the label doesn't specify an offset,
1483    then C{offset} is C{0}.
1484
1485@raise ValueError: The label is malformed.
1486        """
1487
1488        # XXX
1489        # Docstring indentation was removed so epydoc doesn't complain
1490        # when parsing the docs for __use_fuzzy_mode().
1491
1492        # This function is overwritten by __init__
1493        # so here is the static implementation only.
1494        return cls.split_label_strict(label)
1495
1496    # The split_label method is replaced with this function by __init__.
1497    def __use_fuzzy_mode(self, label):
1498        "@see: L{split_label}"
1499        return self.split_label_fuzzy(label)
1500##    __use_fuzzy_mode.__doc__ = split_label.__doc__
1501
1502    def sanitize_label(self, label):
1503        """
1504        Converts a label taken from user input into a well-formed label.
1505
1506        @type  label: str
1507        @param label: Label taken from user input.
1508
1509        @rtype:  str
1510        @return: Sanitized label.
1511        """
1512        (module, function, offset) = self.split_label_fuzzy(label)
1513        label = self.parse_label(module, function, offset)
1514        return label
1515
1516    def resolve_label(self, label):
1517        """
1518        Resolve the memory address of the given label.
1519
1520        @note:
1521            If multiple modules with the same name are loaded,
1522            the label may be resolved at any of them. For a more precise
1523            way to resolve functions use the base address to get the L{Module}
1524            object (see L{Process.get_module}) and then call L{Module.resolve}.
1525
1526            If no module name is specified in the label, the function may be
1527            resolved in any loaded module. If you want to resolve all functions
1528            with that name in all processes, call L{Process.iter_modules} to
1529            iterate through all loaded modules, and then try to resolve the
1530            function in each one of them using L{Module.resolve}.
1531
1532        @type  label: str
1533        @param label: Label to resolve.
1534
1535        @rtype:  int
1536        @return: Memory address pointed to by the label.
1537
1538        @raise ValueError: The label is malformed or impossible to resolve.
1539        @raise RuntimeError: Cannot resolve the module or function.
1540        """
1541
1542        # Split the label into module, function and offset components.
1543        module, function, offset = self.split_label_fuzzy(label)
1544
1545        # Resolve the components into a memory address.
1546        address = self.resolve_label_components(module, function, offset)
1547
1548        # Return the memory address.
1549        return address
1550
1551    def resolve_label_components(self, module   = None,
1552                                       function = None,
1553                                       offset   = None):
1554        """
1555        Resolve the memory address of the given module, function and/or offset.
1556
1557        @note:
1558            If multiple modules with the same name are loaded,
1559            the label may be resolved at any of them. For a more precise
1560            way to resolve functions use the base address to get the L{Module}
1561            object (see L{Process.get_module}) and then call L{Module.resolve}.
1562
1563            If no module name is specified in the label, the function may be
1564            resolved in any loaded module. If you want to resolve all functions
1565            with that name in all processes, call L{Process.iter_modules} to
1566            iterate through all loaded modules, and then try to resolve the
1567            function in each one of them using L{Module.resolve}.
1568
1569        @type  module: None or str
1570        @param module: (Optional) Module name.
1571
1572        @type  function: None, str or int
1573        @param function: (Optional) Function name or ordinal.
1574
1575        @type  offset: None or int
1576        @param offset: (Optional) Offset value.
1577
1578            If C{function} is specified, offset from the function.
1579
1580            If C{function} is C{None}, offset from the module.
1581
1582        @rtype:  int
1583        @return: Memory address pointed to by the label.
1584
1585        @raise ValueError: The label is malformed or impossible to resolve.
1586        @raise RuntimeError: Cannot resolve the module or function.
1587        """
1588        # Default address if no module or function are given.
1589        # An offset may be added later.
1590        address = 0
1591
1592        # Resolve the module.
1593        # If the module is not found, check for the special symbol "main".
1594        if module:
1595            modobj = self.get_module_by_name(module)
1596            if not modobj:
1597                if module == "main":
1598                    modobj = self.get_main_module()
1599                else:
1600                    raise RuntimeError("Module %r not found" % module)
1601
1602            # Resolve the exported function or debugging symbol.
1603            # If all else fails, check for the special symbol "start".
1604            if function:
1605                address = modobj.resolve(function)
1606                if address is None:
1607                    address = modobj.resolve_symbol(function)
1608                    if address is None:
1609                        if function == "start":
1610                            address = modobj.get_entry_point()
1611                        if address is None:
1612                            msg = "Symbol %r not found in module %s"
1613                            raise RuntimeError(msg % (function, module))
1614
1615            # No function, use the base address.
1616            else:
1617                address = modobj.get_base()
1618
1619        # Resolve the function in any module.
1620        # If all else fails, check for the special symbols "main" and "start".
1621        elif function:
1622            for modobj in self.iter_modules():
1623                address = modobj.resolve(function)
1624                if address is not None:
1625                    break
1626            if address is None:
1627                if function == "start":
1628                    modobj = self.get_main_module()
1629                    address = modobj.get_entry_point()
1630                elif function == "main":
1631                    modobj = self.get_main_module()
1632                    address = modobj.get_base()
1633                else:
1634                    msg = "Function %r not found in any module" % function
1635                    raise RuntimeError(msg)
1636
1637        # Return the address plus the offset.
1638        if offset:
1639            address = address + offset
1640        return address
1641
1642    def get_label_at_address(self, address, offset = None):
1643        """
1644        Creates a label from the given memory address.
1645
1646        @warning: This method uses the name of the nearest currently loaded
1647            module. If that module is unloaded later, the label becomes
1648            impossible to resolve.
1649
1650        @type  address: int
1651        @param address: Memory address.
1652
1653        @type  offset: None or int
1654        @param offset: (Optional) Offset value.
1655
1656        @rtype:  str
1657        @return: Label pointing to the given address.
1658        """
1659        if offset:
1660            address = address + offset
1661        modobj = self.get_module_at_address(address)
1662        if modobj:
1663            label = modobj.get_label_at_address(address)
1664        else:
1665            label = self.parse_label(None, None, address)
1666        return label
1667
1668#------------------------------------------------------------------------------
1669
    # The memory addresses of system breakpoints are cached. Since they're
    # all in system libraries, it's not likely they'll ever change their
    # address during the lifetime of the process... I don't suppose a program
    # could happily unload ntdll.dll and survive.
1674    def __get_system_breakpoint(self, label):
1675        try:
1676            return self.__system_breakpoints[label]
1677        except KeyError:
1678            try:
1679                address = self.resolve_label(label)
1680            except Exception:
1681                return None
1682            self.__system_breakpoints[label] = address
1683            return address
1684
1685    # It's in kernel32 in Windows Server 2003, in ntdll since Windows Vista.
1686    # It can only be resolved if we have the debug symbols.
1687    def get_break_on_error_ptr(self):
1688        """
1689        @rtype: int
1690        @return:
1691            If present, returns the address of the C{g_dwLastErrorToBreakOn}
1692            global variable for this process. If not, returns C{None}.
1693        """
1694        address = self.__get_system_breakpoint("ntdll!g_dwLastErrorToBreakOn")
1695        if not address:
1696            address = self.__get_system_breakpoint(
1697                                            "kernel32!g_dwLastErrorToBreakOn")
1698            # cheat a little :)
1699            self.__system_breakpoints["ntdll!g_dwLastErrorToBreakOn"] = address
1700        return address
1701
1702    def is_system_defined_breakpoint(self, address):
1703        """
1704        @type  address: int
1705        @param address: Memory address.
1706
1707        @rtype:  bool
1708        @return: C{True} if the given address points to a system defined
1709            breakpoint. System defined breakpoints are hardcoded into
1710            system libraries.
1711        """
1712        if address:
1713            module = self.get_module_at_address(address)
1714            if module:
1715                return module.match_name("ntdll")    or \
1716                       module.match_name("kernel32")
1717        return False
1718
1719    # FIXME
1720    # In Wine, the system breakpoint seems to be somewhere in kernel32.
1721    def get_system_breakpoint(self):
1722        """
1723        @rtype:  int or None
1724        @return: Memory address of the system breakpoint
1725            within the process address space.
1726            Returns C{None} on error.
1727        """
1728        return self.__get_system_breakpoint("ntdll!DbgBreakPoint")
1729
1730    # I don't know when this breakpoint is actually used...
1731    def get_user_breakpoint(self):
1732        """
1733        @rtype:  int or None
1734        @return: Memory address of the user breakpoint
1735            within the process address space.
1736            Returns C{None} on error.
1737        """
1738        return self.__get_system_breakpoint("ntdll!DbgUserBreakPoint")
1739
1740    # On some platforms, this breakpoint can only be resolved
1741    # when the debugging symbols for ntdll.dll are loaded.
1742    def get_breakin_breakpoint(self):
1743        """
1744        @rtype:  int or None
1745        @return: Memory address of the remote breakin breakpoint
1746            within the process address space.
1747            Returns C{None} on error.
1748        """
1749        return self.__get_system_breakpoint("ntdll!DbgUiRemoteBreakin")
1750
1751    # Equivalent of ntdll!DbgBreakPoint in Wow64.
1752    def get_wow64_system_breakpoint(self):
1753        """
1754        @rtype:  int or None
1755        @return: Memory address of the Wow64 system breakpoint
1756            within the process address space.
1757            Returns C{None} on error.
1758        """
1759        return self.__get_system_breakpoint("ntdll32!DbgBreakPoint")
1760
1761    # Equivalent of ntdll!DbgUserBreakPoint in Wow64.
1762    def get_wow64_user_breakpoint(self):
1763        """
1764        @rtype:  int or None
1765        @return: Memory address of the Wow64 user breakpoint
1766            within the process address space.
1767            Returns C{None} on error.
1768        """
1769        return self.__get_system_breakpoint("ntdll32!DbgUserBreakPoint")
1770
1771    # Equivalent of ntdll!DbgUiRemoteBreakin in Wow64.
1772    def get_wow64_breakin_breakpoint(self):
1773        """
1774        @rtype:  int or None
1775        @return: Memory address of the Wow64 remote breakin breakpoint
1776            within the process address space.
1777            Returns C{None} on error.
1778        """
1779        return self.__get_system_breakpoint("ntdll32!DbgUiRemoteBreakin")
1780
1781#------------------------------------------------------------------------------
1782
1783    def load_symbols(self):
1784        """
1785        Loads the debugging symbols for all modules in this snapshot.
1786        Automatically called by L{get_symbols}.
1787        """
1788        for aModule in self.iter_modules():
1789            aModule.load_symbols()
1790
1791    def unload_symbols(self):
1792        """
1793        Unloads the debugging symbols for all modules in this snapshot.
1794        """
1795        for aModule in self.iter_modules():
1796            aModule.unload_symbols()
1797
1798    def get_symbols(self):
1799        """
1800        Returns the debugging symbols for all modules in this snapshot.
1801        The symbols are automatically loaded when needed.
1802
1803        @rtype:  list of tuple( str, int, int )
1804        @return: List of symbols.
1805            Each symbol is represented by a tuple that contains:
1806                - Symbol name
1807                - Symbol memory address
1808                - Symbol size in bytes
1809        """
1810        symbols = list()
1811        for aModule in self.iter_modules():
1812            for symbol in aModule.iter_symbols():
1813                symbols.append(symbol)
1814        return symbols
1815
1816    def iter_symbols(self):
1817        """
1818        Returns an iterator for the debugging symbols in all modules in this
1819        snapshot, in no particular order.
1820        The symbols are automatically loaded when needed.
1821
1822        @rtype:  iterator of tuple( str, int, int )
1823        @return: Iterator of symbols.
1824            Each symbol is represented by a tuple that contains:
1825                - Symbol name
1826                - Symbol memory address
1827                - Symbol size in bytes
1828        """
1829        for aModule in self.iter_modules():
1830            for symbol in aModule.iter_symbols():
1831                yield symbol
1832
1833    def resolve_symbol(self, symbol, bCaseSensitive = False):
1834        """
1835        Resolves a debugging symbol's address.
1836
1837        @type  symbol: str
1838        @param symbol: Name of the symbol to resolve.
1839
1840        @type  bCaseSensitive: bool
1841        @param bCaseSensitive: C{True} for case sensitive matches,
1842            C{False} for case insensitive.
1843
1844        @rtype:  int or None
1845        @return: Memory address of symbol. C{None} if not found.
1846        """
1847        if bCaseSensitive:
1848            for (SymbolName, SymbolAddress, SymbolSize) in self.iter_symbols():
1849                if symbol == SymbolName:
1850                    return SymbolAddress
1851        else:
1852            symbol = symbol.lower()
1853            for (SymbolName, SymbolAddress, SymbolSize) in self.iter_symbols():
1854                if symbol == SymbolName.lower():
1855                    return SymbolAddress
1856
1857    def get_symbol_at_address(self, address):
1858        """
1859        Tries to find the closest matching symbol for the given address.
1860
1861        @type  address: int
1862        @param address: Memory address to query.
1863
1864        @rtype: None or tuple( str, int, int )
1865        @return: Returns a tuple consisting of:
1866             - Name
1867             - Address
1868             - Size (in bytes)
1869            Returns C{None} if no symbol could be matched.
1870        """
1871        # Any module may have symbols pointing anywhere in memory, so there's
1872        # no easy way to optimize this. I guess we're stuck with brute force.
1873        found = None
1874        for (SymbolName, SymbolAddress, SymbolSize) in self.iter_symbols():
1875            if SymbolAddress > address:
1876                continue
1877
1878            if SymbolAddress == address:
1879                found = (SymbolName, SymbolAddress, SymbolSize)
1880                break
1881
1882            if SymbolAddress < address:
1883                if found and (address - found[1]) < (address - SymbolAddress):
1884                    continue
1885                else:
1886                    found = (SymbolName, SymbolAddress, SymbolSize)
1887        return found
1888#------------------------------------------------------------------------------
1889
1890    # XXX _notify_* methods should not trigger a scan
1891
1892    def _add_module(self, aModule):
1893        """
1894        Private method to add a module object to the snapshot.
1895
1896        @type  aModule: L{Module}
1897        @param aModule: Module object.
1898        """
1899##        if not isinstance(aModule, Module):
1900##            if hasattr(aModule, '__class__'):
1901##                typename = aModule.__class__.__name__
1902##            else:
1903##                typename = str(type(aModule))
1904##            msg = "Expected Module, got %s instead" % typename
1905##            raise TypeError(msg)
1906        lpBaseOfDll = aModule.get_base()
1907##        if lpBaseOfDll in self.__moduleDict:
1908##            msg = "Module already exists: %d" % lpBaseOfDll
1909##            raise KeyError(msg)
1910        aModule.set_process(self)
1911        self.__moduleDict[lpBaseOfDll] = aModule
1912
1913    def _del_module(self, lpBaseOfDll):
1914        """
1915        Private method to remove a module object from the snapshot.
1916
1917        @type  lpBaseOfDll: int
1918        @param lpBaseOfDll: Module base address.
1919        """
1920        try:
1921            aModule = self.__moduleDict[lpBaseOfDll]
1922            del self.__moduleDict[lpBaseOfDll]
1923        except KeyError:
1924            aModule = None
1925            msg = "Unknown base address %d" % HexDump.address(lpBaseOfDll)
1926            warnings.warn(msg, RuntimeWarning)
1927        if aModule:
1928            aModule.clear()     # remove circular references
1929
1930    def __add_loaded_module(self, event):
1931        """
1932        Private method to automatically add new module objects from debug events.
1933
1934        @type  event: L{Event}
1935        @param event: Event object.
1936        """
1937        lpBaseOfDll = event.get_module_base()
1938        hFile       = event.get_file_handle()
1939##        if not self.has_module(lpBaseOfDll):  # XXX this would trigger a scan
1940        if lpBaseOfDll not in self.__moduleDict:
1941            fileName = event.get_filename()
1942            if not fileName:
1943                fileName = None
1944            if hasattr(event, 'get_start_address'):
1945                EntryPoint = event.get_start_address()
1946            else:
1947                EntryPoint = None
1948            aModule  = Module(lpBaseOfDll, hFile, fileName = fileName,
1949                                                EntryPoint = EntryPoint,
1950                                                   process = self)
1951            self._add_module(aModule)
1952        else:
1953            aModule = self.get_module(lpBaseOfDll)
1954            if not aModule.hFile and hFile not in (None, 0,
1955                                                   win32.INVALID_HANDLE_VALUE):
1956                aModule.hFile = hFile
1957            if not aModule.process:
1958                aModule.process = self
1959            if aModule.EntryPoint is None and \
1960                                           hasattr(event, 'get_start_address'):
1961                aModule.EntryPoint = event.get_start_address()
1962            if not aModule.fileName:
1963                fileName = event.get_filename()
1964                if fileName:
1965                    aModule.fileName = fileName
1966
1967    def _notify_create_process(self, event):
1968        """
1969        Notify the load of the main module.
1970
1971        This is done automatically by the L{Debug} class, you shouldn't need
1972        to call it yourself.
1973
1974        @type  event: L{CreateProcessEvent}
1975        @param event: Create process event.
1976
1977        @rtype:  bool
1978        @return: C{True} to call the user-defined handle, C{False} otherwise.
1979        """
1980        self.__add_loaded_module(event)
1981        return True
1982
1983    def _notify_load_dll(self, event):
1984        """
1985        Notify the load of a new module.
1986
1987        This is done automatically by the L{Debug} class, you shouldn't need
1988        to call it yourself.
1989
1990        @type  event: L{LoadDLLEvent}
1991        @param event: Load DLL event.
1992
1993        @rtype:  bool
1994        @return: C{True} to call the user-defined handle, C{False} otherwise.
1995        """
1996        self.__add_loaded_module(event)
1997        return True
1998
1999    def _notify_unload_dll(self, event):
2000        """
2001        Notify the release of a loaded module.
2002
2003        This is done automatically by the L{Debug} class, you shouldn't need
2004        to call it yourself.
2005
2006        @type  event: L{UnloadDLLEvent}
2007        @param event: Unload DLL event.
2008
2009        @rtype:  bool
2010        @return: C{True} to call the user-defined handle, C{False} otherwise.
2011        """
2012        lpBaseOfDll = event.get_module_base()
2013##        if self.has_module(lpBaseOfDll):  # XXX this would trigger a scan
2014        if lpBaseOfDll in self.__moduleDict:
2015            self._del_module(lpBaseOfDll)
2016        return True
2017