1#!/usr/bin/env python
2
3"""
4Copyright (c) 2006-2019 sqlmap developers (http://sqlmap.org/)
5See the file 'LICENSE' for copying permission
6"""
7
8from __future__ import print_function
9
10import difflib
11import threading
12import time
13import traceback
14
15from lib.core.compat import WichmannHill
16from lib.core.compat import xrange
17from lib.core.data import conf
18from lib.core.data import kb
19from lib.core.data import logger
20from lib.core.datatype import AttribDict
21from lib.core.enums import PAYLOAD
22from lib.core.exception import SqlmapBaseException
23from lib.core.exception import SqlmapConnectionException
24from lib.core.exception import SqlmapThreadException
25from lib.core.exception import SqlmapUserQuitException
26from lib.core.exception import SqlmapValueException
27from lib.core.settings import MAX_NUMBER_OF_THREADS
28from lib.core.settings import PYVERSION
29
# One object common to all threads: _ThreadData.reset() stores this same
# instance as the 'shared' attribute of every thread's local data
shared = AttribDict()
31
class _ThreadData(threading.local):
    """
    Per-thread state container (being a threading.local subclass, every
    thread gets its own independent copy of the attributes set in reset())
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """
        Restores every per-thread attribute to its initial value
        """

        # General flags and counters
        self.disableStdOut = False
        self.inTransaction = False
        self.resumed = False
        self.retriesCount = 0
        self.validationRun = 0
        self.technique = None
        self.hashDBCursor = None

        # 'last*' bookkeeping attributes
        self.lastCode = None
        self.lastPage = None
        self.lastHTTPError = None
        self.lastErrorPage = ()
        self.lastRedirectMsg = None
        self.lastRedirectURL = ()
        self.lastRequestMsg = None
        self.lastRequestUID = 0
        self.lastQueryDuration = 0
        self.lastComparisonPage = None
        self.lastComparisonHeaders = None
        self.lastComparisonCode = None
        self.lastComparisonRatio = None

        # Helper objects, freshly constructed on every reset
        self.random = WichmannHill()
        self.seqMatcher = difflib.SequenceMatcher(None)
        self.valueStack = []

        # The module-level object common to all threads
        self.shared = shared
69
# Single module-wide instance; since _ThreadData derives from threading.local,
# each thread that touches it sees (and lazily initializes) its own attributes
ThreadData = _ThreadData()
71
def readInput(message, default=None, checkBatch=True, boolean=False):
    """
    Placeholder for the interactive prompt helper

    The original implementation from lib.core.common overwrites this
    function; until then it ignores its arguments and returns None
    """

    return None
75
def isDigit(value):
    """
    Placeholder for the digit-check helper

    The original implementation from lib.core.common overwrites this
    function; until then it ignores its argument and returns None
    """

    return None
79
def getCurrentThreadData():
    """
    Returns the module-wide ThreadData object; being thread-local, its
    attributes resolve to the calling thread's own values
    """

    return ThreadData
86
def getCurrentThreadName():
    """
    Returns current thread's name
    """

    # The Thread.name property exists since Python 2.6, while getName() is
    # deprecated since Python 3.10
    return threading.current_thread().name
93
def exceptionHandledFunction(threadFunction, silent=False):
    """
    Runs threadFunction, logging (instead of propagating) any ordinary
    exception it raises

    threadFunction -- callable taking no arguments
    silent -- when True, exceptions are swallowed without being logged

    KeyboardInterrupt is re-raised after clearing kb.threadContinue and
    setting kb.threadException
    """

    try:
        threadFunction()
    except KeyboardInterrupt:
        kb.threadContinue = False
        kb.threadException = True
        raise
    except Exception as ex:
        from lib.core.common import getSafeExString

        # Report only while the run is still active and never for a user quit
        if not silent and kb.get("threadContinue") and not isinstance(ex, SqlmapUserQuitException):
            errMsg = getSafeExString(ex) if isinstance(ex, SqlmapBaseException) else "%s: %s" % (type(ex).__name__, getSafeExString(ex))
            logger.error("thread %s: '%s'" % (threading.current_thread().name, errMsg))

            # conf.get() returns None when the option is unset; comparing None
            # with an int raises TypeError on Python 3
            if (conf.get("verbose") or 0) > 1 and not isinstance(ex, SqlmapConnectionException):
                traceback.print_exc()
110
def setDaemon(thread):
    """
    Marks thread as a daemon thread (daemon threads do not block
    interpreter exit)
    """

    # Reference: http://stackoverflow.com/questions/190010/daemon-threads-explanation
    # Feature detection instead of a string comparison against PYVERSION,
    # which orders versions lexicographically. The 'daemon' property exists
    # since Python 2.6, whereas setDaemon() is deprecated since Python 3.10.
    if hasattr(thread, "daemon"):
        thread.daemon = True
    else:
        thread.setDaemon(True)
117
def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardException=True, threadChoice=False, startThreadMsg=True):
    """
    Runs threadFunction in numThreads worker threads and waits for them all
    to finish. When numThreads is 1, threadFunction runs directly in the
    calling thread instead.

    numThreads -- requested number of worker threads
    threadFunction -- callable (no arguments) executed by every worker
    cleanupFunction -- optional callable always invoked once at the end
    forwardException -- re-raise KeyboardInterrupt/SqlmapUserQuitException
                        after the remaining threads have stopped
    threadChoice -- allow interactively asking the user for a thread count
                    (only when running single-threaded and the identified
                    techniques are not exclusively time-based/stacked)
    startThreadMsg -- log the "starting %d threads" message

    On exit kb.threadContinue/kb.threadException/kb.technique are reset,
    any held lock from kb.locks is released and hashDB (if any) is flushed.
    """

    threads = []

    kb.multipleCtrlC = False
    kb.threadContinue = True
    kb.threadException = False
    kb.technique = ThreadData.technique

    if threadChoice and conf.threads == numThreads == 1 and not (kb.injection.data and not any(_ not in (PAYLOAD.TECHNIQUE.TIME, PAYLOAD.TECHNIQUE.STACKED) for _ in kb.injection.data)):
        while True:
            message = "please enter number of threads? [Enter for %d (current)] " % numThreads
            choice = readInput(message, default=str(numThreads))
            if choice:
                skipThreadCheck = False

                # A trailing '!' bypasses the MAX_NUMBER_OF_THREADS limit
                if choice.endswith('!'):
                    choice = choice[:-1]
                    skipThreadCheck = True

                if isDigit(choice):
                    value = int(choice)

                    if value > MAX_NUMBER_OF_THREADS and not skipThreadCheck:
                        errMsg = "maximum number of used threads is %d avoiding potential connection issues" % MAX_NUMBER_OF_THREADS
                        logger.critical(errMsg)
                    elif value < 1:
                        # Zero threads would leave conf.threads in an unusable state
                        errMsg = "number of threads must be greater than zero"
                        logger.critical(errMsg)
                    else:
                        conf.threads = numThreads = value
                        break

        if numThreads == 1:
            warnMsg = "running in a single-thread mode. This could take a while"
            logger.warning(warnMsg)

    try:
        if numThreads > 1:
            if startThreadMsg:
                infoMsg = "starting %d threads" % numThreads
                logger.info(infoMsg)
        else:
            # Single-thread mode: run inline ('finally' below still executes)
            threadFunction()
            return

        # Start the threads
        for numThread in xrange(numThreads):
            thread = threading.Thread(target=exceptionHandledFunction, name=str(numThread), args=[threadFunction])

            setDaemon(thread)

            try:
                thread.start()
            except Exception as ex:
                errMsg = "error occurred while starting new thread ('%s')" % ex
                logger.critical(errMsg)
                break

            threads.append(thread)

        # And wait for them to all finish
        alive = True
        while alive:
            alive = False
            for thread in threads:
                # Thread.isAlive() was removed in Python 3.9; is_alive() exists since 2.6
                if thread.is_alive():
                    alive = True
                    time.sleep(0.1)

    except (KeyboardInterrupt, SqlmapUserQuitException) as ex:
        print()
        kb.prependFlag = False
        kb.threadContinue = False
        kb.threadException = True

        # A second interruption within one second aborts immediately
        if kb.lastCtrlCTime and (time.time() - kb.lastCtrlCTime < 1):
            kb.multipleCtrlC = True
            raise SqlmapUserQuitException("user aborted (Ctrl+C was pressed multiple times)")

        kb.lastCtrlCTime = time.time()

        if numThreads > 1:
            logger.info("waiting for threads to finish%s" % (" (Ctrl+C was pressed)" if isinstance(ex, KeyboardInterrupt) else ""))
        try:
            # Sleep between checks instead of busy-spinning a CPU core
            while threading.active_count() > 1:
                time.sleep(0.1)

        except KeyboardInterrupt:
            kb.multipleCtrlC = True
            raise SqlmapThreadException("user aborted (Ctrl+C was pressed multiple times)")

        if forwardException:
            raise

    except (SqlmapConnectionException, SqlmapValueException) as ex:
        print()
        kb.threadException = True
        logger.error("thread %s: '%s'" % (threading.current_thread().name, ex))

        # conf.get() returns None when unset; None > 1 raises TypeError on Python 3
        if (conf.get("verbose") or 0) > 1 and isinstance(ex, SqlmapValueException):
            traceback.print_exc()

    except:
        # Top-level boundary: report anything else (unless aborting on repeated Ctrl+C)
        print()

        if not kb.multipleCtrlC:
            from lib.core.common import unhandledExceptionMessage

            kb.threadException = True
            errMsg = unhandledExceptionMessage()
            logger.error("thread %s: %s" % (threading.current_thread().name, errMsg))
            traceback.print_exc()

    finally:
        kb.threadContinue = True
        kb.threadException = False
        kb.technique = None

        for lock in kb.locks.values():
            if lock.locked():
                try:
                    lock.release()
                except Exception:
                    # Best effort: the lock may have been released concurrently
                    # between locked() and release()
                    pass

        if conf.get("hashDB"):
            conf.hashDB.flush(True)

        if cleanupFunction:
            cleanupFunction()
243