1###
2# Copyright (c) 2002-2005, Jeremiah Fincher
3# Copyright (c) 2008, James McCoy
4# All rights reserved.
5#
6# Redistribution and use in source and binary forms, with or without
7# modification, are permitted provided that the following conditions are met:
8#
9#   * Redistributions of source code must retain the above copyright notice,
10#     this list of conditions, and the following disclaimer.
11#   * Redistributions in binary form must reproduce the above copyright notice,
12#     this list of conditions, and the following disclaimer in the
13#     documentation and/or other materials provided with the distribution.
14#   * Neither the name of the author of this software nor the name of
15#     contributors to this software may be used to endorse or promote products
16#     derived from this software without specific prior written consent.
17#
18# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21# ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
22# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28# POSSIBILITY OF SUCH DAMAGE.
29###
30
31import gc
32import os
33import csv
34import time
35import codecs
36import fnmatch
37import os.path
38import threading
39import collections
40
41from .. import callbacks, conf, dbi, ircdb, ircutils, log, utils, world
42from ..commands import *
43
class NoSuitableDatabase(Exception):
    """Raised when none of the configured database types can be used.

    :param suitable: iterable of the database type names the caller could
        have used; stored sorted in ``self.suitable`` so the error message
        lists them in a stable order.
    """
    def __init__(self, suitable):
        self.suitable = sorted(suitable)

    def __str__(self):
        message = ('No suitable databases were found.  Suitable databases '
                   'include %L.  If you have one of these databases '
                   'installed, make sure it is listed in the '
                   'supybot.databases configuration variable.')
        return format(message, self.suitable)
55
def DB(filename, types):
    """Return a factory that opens the first usable database for *filename*.

    :param filename: base name of the database; the concrete file is
        ``<filename>.<type>.db`` (sanitized) inside the data directory.
    :param types: mapping of database type name (as listed in the
        ``supybot.databases`` config variable) to a class/callable taking
        ``(path, *args, **kwargs)``.
    :returns: a callable ``MakeDB(*args, **kwargs)`` that walks the
        configured database types in order and instantiates the first one
        present in *types*.  It raises :class:`NoSuitableDatabase` when none
        match.
    """
    # We don't care if any of the DBs are actually available when
    # documenting, so just fake that we found something suitable
    if world.documenting:
        def junk(*args, **kwargs):
            pass
        return junk
    def MakeDB(*args, **kwargs):
        for dbType in conf.supybot.databases():
            # Only the lookup may legitimately raise KeyError here; a
            # KeyError raised by the database constructor itself is a real
            # error and must not be mistaken for "type not supported".
            try:
                factory = types[dbType]
            except KeyError:
                continue
            # Use a new local name: assigning to 'filename' would make it
            # local to MakeDB and break the closure over the outer argument.
            fn = utils.file.sanitizeName('.'.join([filename, dbType, 'db']))
            path = conf.supybot.directories.data.dirize(fn)
            return factory(path, *args, **kwargs)
        raise NoSuitableDatabase(types.keys())
    return MakeDB
76
def makeChannelFilename(filename, channel=None, dirname=None):
    """Return the path of *filename* inside *channel*'s data directory.

    :param filename: only its basename is used.
    :param channel: the channel name; must not be None.  It is first
        resolved through the ``supybot.databases.plugins.channelSpecific``
        link setting, then lowercased (IRC rules) and sanitized.
    :param dirname: directory to use instead of the per-channel data
        directory.
    :returns: ``os.path.join(dirname, basename(filename))``; the directory is
        created if it does not exist yet.
    """
    assert channel is not None, 'Death to those who use None for their channel'
    filename = os.path.basename(filename)
    channelSpecific = conf.supybot.databases.plugins.channelSpecific
    channel = channelSpecific.getChannelLink(channel)
    channel = utils.file.sanitizeName(ircutils.toLower(channel))
    if dirname is None:
        dirname = conf.supybot.directories.data.dirize(channel)
    if not os.path.exists(dirname):
        try:
            os.makedirs(dirname)
        except OSError:
            # Another thread may have created the directory between the
            # existence check and makedirs(); only fail if it's still
            # missing.  (exist_ok is not used: the file supports Python 2.)
            if not os.path.isdir(dirname):
                raise
    return os.path.join(dirname, filename)
88
def getChannel(channel):
    """Return the channel whose database *channel* should actually use.

    The name is resolved through the
    ``supybot.databases.plugins.channelSpecific`` link setting.  *channel*
    must not be None.
    """
    assert channel is not None, 'Death to those who use None for their channel'
    setting = conf.supybot.databases.plugins.channelSpecific
    linked = setting.getChannelLink(channel)
    return linked
93
94# XXX This shouldn't be a mixin.  This should be contained by classes that
95#     want such behavior.  But at this point, it wouldn't gain much for us
96#     to refactor it.
97# XXX We need to get rid of this, it's ugly and opposed to
98#     database-independence.
class ChannelDBHandler(object):
    """A class to handle database stuff for individual channels transparently.
    """
    # File suffix for the per-channel databases; subclasses override this
    # class attribute to change it.
    suffix = '.db'
    def __init__(self, suffix='.db'):
        # NOTE: the 'suffix' argument has always been ignored in favour of
        # the class attribute above.  It is kept ignored on purpose so that
        # existing databases keep their filenames; set the class attribute
        # instead.
        self.dbCache = ircutils.IrcDict()
        suffix = self.suffix
        if suffix and not suffix.startswith('.'):
            suffix = '.' + suffix
        self.suffix = suffix

    def makeFilename(self, channel):
        """Override this to specialize the filenames of your databases."""
        channel = ircutils.toLower(channel)
        className = self.__class__.__name__
        return makeChannelFilename(className + self.suffix, channel)

    def makeDb(self, filename):
        """Override this to create your databases."""
        raise NotImplementedError

    def getDb(self, channel):
        """Use this to get a database for a specific channel.

        In the main thread, databases are cached per channel.  Any other
        thread gets a freshly created, uncached database.  The returned
        object's ``isolation_level`` is set to None.
        """
        if threading.current_thread() is not world.mainThread:
            db = self.makeDb(self.makeFilename(channel))
        else:
            try:
                db = self.dbCache[channel]
            except KeyError:
                db = self.makeDb(self.makeFilename(channel))
                self.dbCache[channel] = db
        db.isolation_level = None
        return db

    def die(self):
        """Commit and close every cached database, then drop the cache."""
        for db in self.dbCache.values():
            try:
                db.commit()
            except AttributeError: # In case it's not an SQLite database.
                pass
            try:
                db.close()
            except AttributeError: # In case it doesn't have a close method.
                pass
        # Drop our references: otherwise gc.collect() cannot reclaim the
        # handles, and later getDb() calls would return closed databases.
        self.dbCache.clear()
        gc.collect()
144
145
class DbiChannelDB(object):
    """This just handles some of the general stuff for Channel DBI databases.
    Check out ChannelIdDatabasePlugin for an example of how to use this.

    Subclasses must define a ``DB`` attribute: a callable taking a filename
    and returning the per-channel database.  Any attribute not found on this
    object is turned into a dispatcher ``f(channel, *args, **kwargs)`` that
    calls the same-named method on that channel's database.
    """
    def __init__(self, filename):
        self.filename = filename
        self.dbs = ircutils.IrcDict()

    def _getDb(self, channel):
        """Return (creating and caching it if needed) *channel*'s database."""
        try:
            return self.dbs[channel]
        except KeyError:
            pass
        # Only compute the path on a cache miss: makeChannelFilename does
        # config lookups and may touch the filesystem.
        db = self.DB(makeChannelFilename(self.filename, channel))
        self.dbs[channel] = db
        return db

    def close(self):
        for db in self.dbs.values():
            db.close()

    def flush(self):
        for db in self.dbs.values():
            db.flush()

    def __getattr__(self, attr):
        def _getDbAndDispatcher(channel, *args, **kwargs):
            db = self._getDb(channel)
            return getattr(db, attr)(*args, **kwargs)
        return _getDbAndDispatcher
175
176
try:
    from collections.abc import MutableMapping as _MutableMapping
except ImportError:  # Python 2: the ABCs live directly in 'collections'.
    from collections import MutableMapping as _MutableMapping


class ChannelUserDictionary(_MutableMapping):
    """A mapping keyed by ``(channel, id)`` pairs.

    Entries are stored as ``self.channels[channel][id]``.  ``channels`` is an
    ``ircutils.IrcDict``, and each per-channel container is an instance of
    the ``IdDict`` class attribute (a plain dict by default).
    """
    IdDict = dict
    def __init__(self):
        self.channels = ircutils.IrcDict()

    def __getitem__(self, key):
        (channel, id) = key
        return self.channels[channel][id]

    def __setitem__(self, key, v):
        (channel, id) = key
        if channel not in self.channels:
            self.channels[channel] = self.IdDict()
        self.channels[channel][id] = v

    def __delitem__(self, key):
        (channel, id) = key
        del self.channels[channel][id]

    def __iter__(self):
        for channel, ids in self.channels.items():
            for id_ in ids:
                yield (channel, id_)

    def __len__(self):
        # Count the stored (channel, id) entries, not the characters of
        # the channel names.
        return sum(len(ids) for ids in self.channels.values())

    def items(self):
        """Yield ``((channel, id), value)`` pairs."""
        for (channel, ids) in self.channels.items():
            for (id, v) in ids.items():
                yield ((channel, id), v)

    def keys(self):
        """Return a list of all ``(channel, id)`` keys."""
        return [k for (k, _) in self.items()]
214
215
216# XXX The interface to this needs to be made *much* more like the dbi.DB
217#     interface.  This is just too odd and not extensible; any extension
218#     would very much feel like an extension, rather than part of the db
219#     itself.
class ChannelUserDB(ChannelUserDictionary):
    """A ChannelUserDictionary persisted as a UTF-8 CSV file.

    Each row is ``channel, id, field1, field2, ...``.  Subclasses implement
    :meth:`serialize` and :meth:`deserialize` to convert values to and from
    the trailing fields.
    """
    def __init__(self, filename):
        ChannelUserDictionary.__init__(self)
        self.filename = filename
        try:
            fd = codecs.open(self.filename, encoding='utf8')
        except EnvironmentError as e:
            log.warning('Couldn\'t open %s: %s.', self.filename, e)
            return
        try:
            reader = csv.reader(fd)
            lineno = 0
            try:
                for t in reader:
                    lineno += 1
                    try:
                        channel = t.pop(0)
                        id = t.pop(0)
                        try:
                            id = int(id)
                        except ValueError:
                            # We'll skip over this so, say, nicks can be kept here.
                            pass
                        v = self.deserialize(channel, id, t)
                        self[channel, id] = v
                    except Exception as e:
                        log.warning('Invalid line #%s in %s.',
                                    lineno, self.__class__.__name__)
                        log.debug('Exception: %s', utils.exnToString(e))
            except Exception as e: # This catches exceptions from csv.reader.
                log.warning('Invalid line #%s in %s.',
                            lineno, self.__class__.__name__)
                log.debug('Exception: %s', utils.exnToString(e))
        finally:
            # The file was previously never closed, leaking a descriptor.
            fd.close()

    def flush(self):
        """Atomically rewrite the CSV file with the current contents.

        An empty database is never written (the existing file is kept).  If
        serialization or writing fails, the temporary file is rolled back,
        leaving the previous file intact, and the exception propagates.
        """
        mode = 'wb' if utils.minisix.PY2 else 'w'
        fd = utils.file.AtomicFile(self.filename, mode, makeBackupIfSmaller=False)
        items = list(self.items())
        if not items:
            log.debug('%s: Refusing to write blank file.',
                      self.__class__.__name__)
            fd.rollback()
            return
        try:
            items.sort()
        except TypeError:
            # FIXME: Implement an algorithm that can order dictionaries
            # with both strings and integers as keys.
            pass
        try:
            writer = csv.writer(fd)
            for ((channel, id), v) in items:
                # Build a new row instead of inserting into the list returned
                # by serialize(), which may be shared with the stored value.
                writer.writerow([channel, id] + list(self.serialize(v)))
        except Exception:
            fd.rollback()
            raise
        fd.close()

    def close(self):
        self.flush()
        self.clear()

    def deserialize(self, channel, id, L):
        """Should take a list of strings and return an object to be accessed
        via self.get(channel, id)."""
        raise NotImplementedError

    def serialize(self, x):
        """Should take an object (as returned by self.get(channel, id)) and
        return a list (of any type serializable to csv)."""
        raise NotImplementedError
289
290
def getUserName(id):
    """Return a printable name for a record author.

    Integer ids are looked up in ``ircdb.users``.  An id that is no longer
    registered yields a placeholder phrase.  Any other value (for example
    a hostmask recorded for an unregistered user) is returned unchanged.
    """
    if not isinstance(id, int):
        return id
    try:
        name = ircdb.users.getUser(id).name
    except KeyError:
        name = 'a user that is no longer registered'
    return name
299
class ChannelIdDatabasePlugin(callbacks.Plugin):
    # Base class for plugins that store per-channel, id-numbered text
    # records (text, author, timestamp) and expose add/remove/search/get/
    # change/stats commands over them.
    class DB(DbiChannelDB):
        class DB(dbi.DB):
            class Record(dbi.Record):
                __fields__ = [
                    'at',
                    'by',
                    'text'
                    ]
            def add(self, at, by, text, **kwargs):
                record = self.Record(at=at, by=by, text=text, **kwargs)
                # Name the class explicitly: super(self.__class__, self)
                # recurses forever once this class is subclassed.
                return super(ChannelIdDatabasePlugin.DB.DB, self).add(record)

    def __init__(self, irc):
        self.__parent = super(ChannelIdDatabasePlugin, self)
        self.__parent.__init__(irc)
        self.db = DB(self.name(), {'flat': self.DB})()

    def die(self):
        self.db.close()
        self.__parent.die()

    def getCommandHelp(self, name, simpleSyntax=None):
        """Return the command help with $Types/$Type/$types/$type replaced
        by (pluralized / lowercased) forms of the plugin's name."""
        help = self.__parent.getCommandHelp(name, simpleSyntax)
        help = help.replace('$Types', format('%p', self.name()))
        help = help.replace('$Type', self.name())
        help = help.replace('$types', format('%p', self.name().lower()))
        help = help.replace('$type', self.name().lower())
        return help

    def noSuchRecord(self, irc, channel, id):
        """Reply with an error saying record #id doesn't exist in channel."""
        irc.error('There is no %s with id #%s in my database for %s.' %
                  (self.name(), id, channel))

    def checkChangeAllowed(self, irc, msg, channel, user, record):
        """Return True if the caller may modify record.

        The author (matched by user id or by hostmask) and holders of the
        channel's 'op' capability are allowed; otherwise
        irc.errorNoCapability() is called.
        """
        # Checks and returns True if either the user ID (integer)
        # or the hostmask of the caller match.
        if (hasattr(user, 'id') and user.id == record.by) or user == record.by:
            return True
        cap = ircdb.makeChannelCapability(channel, 'op')
        if ircdb.checkCapability(msg.prefix, cap):
            return True
        irc.errorNoCapability(cap)

    def addValidator(self, irc, text):
        """This should irc.error or raise an exception if text is invalid."""
        pass

    def getUserId(self, irc, prefix, channel=None):
        """Return the registered user id for prefix, or None if unregistered.

        When registration is required for channel, an unregistered caller
        gets irc.errorNotRegistered(Raise=True) instead.
        """
        try:
            user = ircdb.users.getUser(prefix)
            return user.id
        except KeyError:
            if conf.get(conf.supybot.databases.plugins.requireRegistration, channel):
                irc.errorNotRegistered(Raise=True)
            return

    def add(self, irc, msg, args, channel, text):
        """[<channel>] <text>

        Adds <text> to the $type database for <channel>.
        <channel> is only necessary if the message isn't sent in the channel
        itself.
        """
        # Unregistered authors are recorded by hostmask.
        user = self.getUserId(irc, msg.prefix, channel) or msg.prefix
        at = time.time()
        self.addValidator(irc, text)
        if text is not None:
            id = self.db.add(channel, at, user, text)
            irc.replySuccess('%s #%s added.' % (self.name(), id))
    add = wrap(add, ['channeldb', 'text'])

    def remove(self, irc, msg, args, channel, id):
        """[<channel>] <id>

        Removes the $type with id <id> from the $type database for <channel>.
        <channel> is only necessary if the message isn't sent in the channel
        itself.
        """
        user = self.getUserId(irc, msg.prefix, channel) or msg.prefix
        try:
            record = self.db.get(channel, id)
            self.checkChangeAllowed(irc, msg, channel, user, record)
            self.db.remove(channel, id)
            irc.replySuccess()
        except KeyError:
            self.noSuchRecord(irc, channel, id)
    remove = wrap(remove, ['channeldb', 'id'])

    def searchSerializeRecord(self, record):
        """Format one search hit as '#<id>: "<text truncated to 50>"'."""
        text = utils.str.ellipsisify(record.text, 50)
        return format('#%s: %q', record.id, text)

    def search(self, irc, msg, args, channel, optlist, glob):
        """[<channel>] [--{regexp,by} <value>] [<glob>]

        Searches for $types matching the criteria given.
        """
        predicates = []
        def p(record):
            return all(predicate(record) for predicate in predicates)

        for (opt, arg) in optlist:
            if opt == 'by':
                predicates.append(lambda r, arg=arg: r.by == arg.id)
            elif opt == 'regexp':
                if not ircdb.checkCapability(msg.prefix, 'trusted'):
                    # Limited --regexp to trusted users, because specially
                    # crafted regexps can freeze the bot. See
                    # https://github.com/ProgVal/Limnoria/issues/855 for details
                    irc.errorNoCapability('trusted')

                # Bind arg as a default value: a plain closure would see only
                # the last --regexp given once this loop has finished.
                predicates.append(lambda r, arg=arg: regexp_wrapper(
                    r.text, reobj=arg, timeout=0.1,
                    plugin_name=self.name(), fcn_name='search'))
        if glob:
            def globP(r, glob=glob.lower()):
                return fnmatch.fnmatch(r.text.lower(), glob)
            predicates.append(globP)
        L = [self.searchSerializeRecord(record)
             for record in self.db.select(channel, p)]
        if L:
            L.sort()
            irc.reply(format('%s found: %L', len(L), L))
        else:
            what = self.name().lower()
            irc.reply(format('No matching %p were found.', what))
    search = wrap(search, ['channeldb',
                           getopts({'by': 'otherUser',
                                    'regexp': 'regexpMatcher'}),
                           additional(rest('glob'))])

    def showRecord(self, record):
        """Format a full record: text, author name and timestamp."""
        name = getUserName(record.by)
        return format('%s #%s: %q (added by %s at %t)',
                      self.name(), record.id, record.text, name, record.at)

    def get(self, irc, msg, args, channel, id):
        """[<channel>] <id>

        Gets the $type with id <id> from the $type database for <channel>.
        <channel> is only necessary if the message isn't sent in the channel
        itself.
        """
        try:
            record = self.db.get(channel, id)
            irc.reply(self.showRecord(record))
        except KeyError:
            self.noSuchRecord(irc, channel, id)
    get = wrap(get, ['channeldb', 'id'])

    def change(self, irc, msg, args, channel, id, replacer):
        """[<channel>] <id> <regexp>

        Changes the $type with id <id> according to the regular expression
        <regexp>.  <channel> is only necessary if the message isn't sent in the
        channel itself.
        """
        user = self.getUserId(irc, msg.prefix, channel) or msg.prefix
        try:
            record = self.db.get(channel, id)
            self.checkChangeAllowed(irc, msg, channel, user, record)
            record.text = replacer(record.text)
            self.db.set(channel, id, record)
            irc.replySuccess()
        except KeyError:
            self.noSuchRecord(irc, channel, id)
    change = wrap(change, ['channeldb', 'id', 'regexpReplacer'])

    def stats(self, irc, msg, args, channel):
        """[<channel>]

        Returns the number of $types in the database for <channel>.
        <channel> is only necessary if the message isn't sent in the channel
        itself.
        """
        n = self.db.size(channel)
        whats = self.name().lower()
        irc.reply(format('There %b %n in my database.', n, (n, whats)))
    stats = wrap(stats, ['channeldb'])
483
484
class PeriodicFileDownloader(object):
    """A class to periodically download a file/files.

    A class-level dictionary 'periodicFiles' maps names of files to
    three-tuples of
    (url, seconds between downloads, function to run with downloaded file).

    'url' is passed to utils.web.getUrlFd, so it must be in a form that
    function can handle.

    'seconds between downloads' is the number of seconds between downloads,
    obviously.  An important point to remember, however, is that it is only
    engaged when a command is run.  I.e., if you say you want the file
    downloaded every day, but no commands that use it are run in a week, the
    next time such a command is run, it'll be using a week-old file.  If you
    don't want such behavior, you'll have to give an error message to the user
    and tell them to call you back in the morning.

    'function to run with downloaded file' is a function that will be passed
    a string *filename* of the downloaded file.  This will be some random
    filename probably generated via some mktemp-type-thing.  You can do what
    you want with this; you may want to build a database, take some stats,
    or simply rename the file.  You can pass None as your function and the
    file will automatically be renamed to match the filename you have it listed
    under.  It'll be in conf.supybot.directories.data, of course.

    Aside from that dictionary, simply use self.getFile(filename) in any method
    that makes use of a periodically downloaded file, and you'll be set.

    Downloads run in daemon threads; the class expects a ``self.log`` logger
    to be provided by the class it is mixed into.
    """
    periodicFiles = None
    def __init__(self):
        if self.periodicFiles is None:
            raise ValueError('You must provide files to download')
        self.lastDownloaded = {}
        self.downloadedCounter = {}
        # Created once, before any download starts: it used to be reset for
        # every file, so a thread started for an earlier file would fail
        # removing its entry (and it was never created for an empty dict).
        self.currentlyDownloading = set()
        for filename in self.periodicFiles:
            if self.periodicFiles[filename][-1] is None:
                fullname = os.path.join(conf.supybot.directories.data(),
                                        filename)
                if os.path.exists(fullname):
                    self.lastDownloaded[filename] = os.stat(fullname).st_ctime
                else:
                    self.lastDownloaded[filename] = 0
            else:
                self.lastDownloaded[filename] = 0
            self.downloadedCounter[filename] = 0
            self.getFile(filename)

    def _downloadFile(self, filename, url, f):
        """Download url to a temp file in the data dir, then process it.

        If f is None the temp file is renamed to filename; otherwise f is
        called with the temp file's path.  Runs in a worker thread.
        """
        self.currentlyDownloading.add(filename)
        try:
            try:
                infd = utils.web.getUrlFd(url)
            except (IOError, utils.web.Error) as e:
                self.log.warning('Error downloading %s: %s', url, e)
                return
            confDir = conf.supybot.directories.data()
            newFilename = os.path.join(confDir, utils.file.mktemp())
            start = time.time()
            # Close both handles even if reading or writing fails.
            try:
                outfd = open(newFilename, 'wb')
                try:
                    s = infd.read(4096)
                    while s:
                        outfd.write(s)
                        s = infd.read(4096)
                finally:
                    outfd.close()
            finally:
                infd.close()
            self.log.info('Downloaded %s in %s seconds',
                          filename, time.time()-start)
            self.downloadedCounter[filename] += 1
            self.lastDownloaded[filename] = time.time()
            if f is None:
                toFilename = os.path.join(confDir, filename)
                if os.name == 'nt':
                    # Windows, grrr...
                    if os.path.exists(toFilename):
                        os.remove(toFilename)
                os.rename(newFilename, toFilename)
            else:
                start = time.time()
                f(newFilename)
                total = time.time() - start
                self.log.info('Function ran on %s in %s seconds',
                              filename, total)
        finally:
            self.currentlyDownloading.discard(filename)

    def getFile(self, filename):
        """Start a background download of filename if it is stale.

        Nothing happens if the last download is recent enough, or if a
        download of that file is already in progress.
        """
        if world.documenting:
            return
        (url, timeLimit, f) = self.periodicFiles[filename]
        if time.time() - self.lastDownloaded[filename] > timeLimit and \
           filename not in self.currentlyDownloading:
            self.log.info('Beginning download of %s', url)
            # Mark the download as in progress before the thread starts, so a
            # second call in between cannot spawn a duplicate download.
            self.currentlyDownloading.add(filename)
            name = '%s #%s' % (filename, self.downloadedCounter[filename])
            t = threading.Thread(target=self._downloadFile, name=name,
                                 args=(filename, url, f))
            t.daemon = True
            try:
                t.start()
            except Exception:
                self.currentlyDownloading.discard(filename)
                raise
            world.threadsSpawned += 1
588            world.threadsSpawned += 1
589
590
591
592
593# vim:set shiftwidth=4 softtabstop=4 expandtab textwidth=79:
594