1#!/usr/local/bin/python3.8
2
3import os
4import sys
5import re
6import configparser
7import getpass
8from collections import namedtuple
9
10try:
11    import pymysql
12    HAS_PYMYSQL = True
13except ImportError:
14    HAS_PYMYSQL = False
15
# Connection settings for one MySQL database: host, port, user name, password and db name.
MySQLDBInfo = namedtuple('MySQLDBInfo', ['host', 'port', 'username', 'password', 'db'])
17
class EnvManager(object):
    """Resolve the directories this script works with.

    The upgrade/install/top directories are derived from the location of this
    file: ``install_path`` is the parent of ``upgrade_dir`` and ``top_dir`` is
    the parent of ``install_path``. The ccnet and seafile config directories
    come from the CCNET_CONF_DIR and SEAFILE_CONF_DIR environment variables;
    a KeyError is raised if either one is unset.
    """

    def __init__(self):
        this_dir = os.path.dirname(__file__)
        self.upgrade_dir = os.path.abspath(this_dir)
        self.install_path = os.path.dirname(self.upgrade_dir)
        self.top_dir = os.path.dirname(self.install_path)

        environ = os.environ
        self.ccnet_dir = environ['CCNET_CONF_DIR']
        self.seafile_dir = environ['SEAFILE_CONF_DIR']
25
# Module-level instance shared by the config readers below. Constructing it
# reads CCNET_CONF_DIR and SEAFILE_CONF_DIR from the environment, so importing
# this module raises KeyError if either variable is unset.
env_mgr = EnvManager()
27
class Utils(object):
    """Terminal-output and config-file helpers used by this script."""

    @staticmethod
    def highlight(content, is_error=False):
        '''Add ANSI color to content to get it highlighted on terminal.

        Bold red when ``is_error`` is true, bold green otherwise.
        '''
        if is_error:
            return '\x1b[1;31m%s\x1b[m' % content
        else:
            return '\x1b[1;32m%s\x1b[m' % content

    @staticmethod
    def info(msg):
        '''Print ``msg`` to stdout prefixed with a green ``[INFO]`` tag.'''
        print(Utils.highlight('[INFO] ') + msg)

    @staticmethod
    def error(msg):
        '''Print ``msg`` prefixed with a red ``[ERROR]`` tag, then exit with status 1.'''
        # is_error=True so errors render red and stand out from [INFO] lines.
        print(Utils.highlight('[ERROR] ', is_error=True) + msg)
        sys.exit(1)

    @staticmethod
    def read_config(config_path, defaults):
        '''Parse ``config_path`` with ConfigParser and return the parser.

        ``defaults`` populate the DEFAULT section and are therefore visible
        from every section. A missing file is silently ignored by
        ConfigParser.read, leaving only the defaults. Note that the default
        BasicInterpolation is active, so a value containing a bare '%' raises
        an interpolation error when fetched.
        '''
        cp = configparser.ConfigParser(defaults)
        cp.read(config_path)
        return cp
51
def get_ccnet_mysql_info():
    """Read the ccnet database settings from ``<ccnet_dir>/ccnet.conf``.

    Returns:
        A MySQLDBInfo built from the [Database] section, or None when that
        section is absent or its ENGINE is not 'mysql' (a missing ENGINE is
        treated as "not MySQL" rather than crashing).

    HOST and PORT default to 127.0.0.1 and 3306. If USER, PASSWD or DB is
    missing, or PORT is not an integer, the problem is reported through
    Utils.error, which exits the process.
    """
    ccnet_conf = os.path.join(env_mgr.ccnet_dir, 'ccnet.conf')
    defaults = {
        'HOST': '127.0.0.1',
        'PORT': '3306',
    }

    config = Utils.read_config(ccnet_conf, defaults)
    db_section = 'Database'

    if not config.has_section(db_section):
        return None

    # fallback avoids an unhandled NoOptionError when ENGINE is not set.
    engine = config.get(db_section, 'ENGINE', fallback=None)
    if engine != 'mysql':
        return None

    try:
        host = config.get(db_section, 'HOST')
        port = config.getint(db_section, 'PORT')
        username = config.get(db_section, 'USER')
        password = config.get(db_section, 'PASSWD')
        db = config.get(db_section, 'DB')
    except (configparser.NoOptionError, ValueError) as e:
        # ValueError comes from getint() on a non-numeric PORT.
        Utils.error('Database config in ccnet.conf is invalid: %s' % e)

    info = MySQLDBInfo(host, port, username, password, db)
    return info
80
def get_seafile_mysql_info():
    """Read the seafile database settings from ``<seafile_dir>/seafile.conf``.

    Returns:
        A MySQLDBInfo built from the [database] section, or None when that
        section is absent or its ``type`` is not 'mysql' (a missing ``type``
        is treated as "not MySQL" rather than crashing).

    host and port default to 127.0.0.1 and 3306. If user, password or
    db_name is missing, or port is not an integer, the problem is reported
    through Utils.error, which exits the process.
    """
    seafile_conf = os.path.join(env_mgr.seafile_dir, 'seafile.conf')
    defaults = {
        'HOST': '127.0.0.1',
        'PORT': '3306',
    }
    config = Utils.read_config(seafile_conf, defaults)
    db_section = 'database'

    if not config.has_section(db_section):
        return None

    # fallback avoids an unhandled NoOptionError when 'type' is not set.
    db_type = config.get(db_section, 'type', fallback=None)
    if db_type != 'mysql':
        return None

    try:
        host = config.get(db_section, 'host')
        port = config.getint(db_section, 'port')
        username = config.get(db_section, 'user')
        password = config.get(db_section, 'password')
        db = config.get(db_section, 'db_name')
    except (configparser.NoOptionError, ValueError) as e:
        # ValueError comes from getint() on a non-numeric port.
        Utils.error('Database config in seafile.conf is invalid: %s' % e)

    info = MySQLDBInfo(host, port, username, password, db)
    return info
108
def get_seahub_mysql_info():
    """Read the seahub (Django) database settings from seahub_settings.py.

    ``env_mgr.top_dir`` is prepended to sys.path so that seahub_settings.py
    located there can be imported.

    Returns:
        A MySQLDBInfo for DATABASES['default'], or None when seahub_settings
        has no DATABASES or the default engine is not
        'django.db.backends.mysql'.

    An import failure, a missing key, or a non-integer PORT is reported
    through Utils.error, which exits the process.
    """
    sys.path.insert(0, env_mgr.top_dir)
    try:
        import seahub_settings  # pylint: disable=F0401
    except ImportError as e:
        Utils.error('Failed to import seahub_settings.py: %s' % e)

    if not hasattr(seahub_settings, 'DATABASES'):
        return None

    try:
        d = seahub_settings.DATABASES['default']
        if d['ENGINE'] != 'django.db.backends.mysql':
            return None

        # Django treats an empty HOST/PORT as "use the default", so fall back
        # on any falsy value, not only on a missing key (int('') would raise).
        host = d.get('HOST') or '127.0.0.1'
        port = int(d.get('PORT') or 3306)
        username = d['USER']
        password = d['PASSWORD']
        db = d['NAME']
    except (KeyError, ValueError) as e:
        # Bind ``e``: the original handler referenced it without binding it,
        # turning a config error into a NameError.
        Utils.error('Database config in seahub_settings.py is invalid: %s' % e)

    info = MySQLDBInfo(host, port, username, password, db)
    return info
134
def get_seafile_db_infos():
    """Collect the ccnet, seafile and seahub MySQL settings, in that order.

    All three readers are always called first. The list of three MySQLDBInfo
    records is returned only when every one exists and its host is
    'localhost' or '127.0.0.1'; otherwise None is returned.
    """
    collected = [
        get_ccnet_mysql_info(),
        get_seafile_mysql_info(),
        get_seahub_mysql_info(),
    ]

    local_hosts = ('localhost', '127.0.0.1')
    all_local_mysql = all(
        entry is not None and entry.host in local_hosts for entry in collected
    )
    return collected if all_local_mysql else None
148
def ask_root_password(port):
    """Prompt for the MySQL root password until a connection succeeds.

    Blank input re-prompts without a message. When check_mysql_user raises
    InvalidAnswer, the reason is printed and the user is asked again.
    Returns the connection object produced by check_mysql_user.
    """
    prompt = 'What is the root password for mysql? '
    while True:
        answer = getpass.getpass(prompt).strip()
        if not answer:
            continue
        try:
            return check_mysql_user('root', answer, port)
        except InvalidAnswer as err:
            print('\n%s\n' % err)
159
class InvalidAnswer(Exception):
    """Signals that an answer supplied by the user was rejected.

    The message is stored on ``msg`` and is what ``str()`` returns. It is not
    forwarded to Exception, so ``args`` stays an empty tuple.
    """

    def __init__(self, msg):
        super(InvalidAnswer, self).__init__()
        self.msg = msg

    def __str__(self):
        return self.msg
167
def check_mysql_user(user, password, port):
    """Open a pymysql connection to localhost:``port`` as ``user``.

    Prints a progress line, then 'done' on success, and returns the open
    connection. Any failure to connect is re-raised as InvalidAnswer. For an
    OperationalError the message uses the driver's text (``args[1]``);
    for anything else it uses the exception itself.
    """
    print('\nverifying password of root user %s ... ' % user, end=' ')
    connect_params = {
        'host': 'localhost',
        'port': port,
        'user': user,
        'passwd': password,
    }

    try:
        conn = pymysql.connect(**connect_params)
    except pymysql.err.OperationalError as err:
        reason = err.args[1]
        raise InvalidAnswer('Failed to connect to mysql server using user "%s" and password "***": %s'
                            % (user, reason))
    except Exception as err:
        raise InvalidAnswer('Failed to connect to mysql server using user "%s" and password "***": %s'
                            % (user, err))

    print('done')
    return conn
187
def apply_fix(root_conn, user, dbs):
    """Grant ``user`` privileges on each database and drop a passwordless
    ``user``@'%' account if one exists.

    Args:
        root_conn: an open pymysql connection with enough privilege to run
            GRANT and DROP USER.
        user: MySQL user name to fix.
        dbs: iterable of database names to grant on.
    """
    for db in dbs:
        grant_db_permission(root_conn, user, db)

    # The user name is passed as a driver parameter so it is escaped. When
    # args are supplied, pymysql %-formats the query, so '%%' becomes the
    # literal '%' host wildcard.
    # NOTE(review): assumes mysql.user has a `password` column; some MySQL
    # versions only have `authentication_string` — confirm for target servers.
    find_sql = """
    SELECT *
    FROM mysql.user
    WHERE Host = '%%'
      AND password = ''
      AND User = %s
    """
    cursor = root_conn.cursor()
    try:
        cursor.execute(find_sql, (user,))
        if cursor.rowcount > 0:
            # Account names may be given as quoted strings: 'user'@'%'.
            cursor.execute("DROP USER %s@'%%'", (user,))
    finally:
        # Always release the cursor, as grant_db_permission does.
        cursor.close()
204
def grant_db_permission(conn, user, db):
    """Run ``GRANT ALL PRIVILEGES ON `db`.* to `user`@localhost``.

    Failure is reported through Utils.error, which exits the process. For an
    OperationalError the message uses the driver's text (``args[1]``). The
    cursor is closed in every case.
    """
    cur = conn.cursor()
    grant_sql = '''GRANT ALL PRIVILEGES ON `%s`.* to `%s`@localhost ''' % (db, user)

    try:
        cur.execute(grant_sql)
    except pymysql.err.OperationalError as err:
        Utils.error('Failed to grant permission of database %s: %s' % (db, err.args[1]))
    except Exception as err:
        Utils.error('Failed to grant permission of database %s: %s' % (db, err))
    finally:
        cur.close()
220
def main():
    """Fix MySQL grants for the Seafile database user.

    Does nothing unless all three databases (ccnet, seafile, seahub) are
    MySQL on localhost/127.0.0.1, or when the configured user is root.
    Otherwise it prompts for the MySQL root password and applies the fix,
    closing the root connection afterwards.
    """
    dbinfos = get_seafile_db_infos()
    if not dbinfos:
        return
    user = dbinfos[0].username
    if user == 'root':
        return

    if not HAS_PYMYSQL:
        Utils.error('Python pymysql module is not found')
    root_conn = ask_root_password(dbinfos[0].port)
    try:
        # NOTE(review): every database is granted to the ccnet DB user;
        # presumably all three configs share that user — confirm.
        apply_fix(root_conn, user, [info.db for info in dbinfos])
    finally:
        # Also runs when apply_fix exits via Utils.error (SystemExit).
        root_conn.close()
232
# Run the fix only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
235