1# (C) Copyright 2007-2019 Enthought, Inc., Austin, TX
2# All rights reserved.
3#
4# This software is provided without warranty under the terms of the BSD
5# license included in LICENSE.txt and may be redistributed only
6# under the conditions described in the aforementioned license.  The license
7# is also available online at http://www.enthought.com/licenses/BSD.txt
8# Thanks for using Enthought open source!
9
10import os
11import shutil
12import tempfile
13
14
class ETSConfigPatcher(object):
    """
    Object that patches the directories in ETSConfig, to avoid having
    tests write to the home directory.

    ``start()`` creates a fresh temporary directory and points
    ``ETSConfig._application_data``, ``ETSConfig._application_home`` and
    ``ETSConfig._user_data`` at paths inside it. ``stop()`` restores the
    original values (including ``None``) and deletes the temporary
    directory. Only the temporary directory itself is created; the
    sub-paths are not.

    The patcher can also be used as a context manager::

        with ETSConfigPatcher():
            ...
    """
    def __init__(self):
        # Imported here rather than at module level; presumably so that
        # importing this module does not require traits until a patcher is
        # actually created.
        from traits.etsconfig.api import ETSConfig
        self.etsconfig = ETSConfig

        # Root of the temporary directory while patched, otherwise None.
        # This is the single source of truth for "is the patch active?".
        # The old_* values below may legitimately be None, so they cannot
        # serve as that flag.
        self.tmpdir = None
        self.old_application_data = None
        self.old_application_home = None
        self.old_user_data = None

    def start(self):
        """
        Redirect the ETSConfig directories into a new temporary directory.

        Raises
        ------
        RuntimeError
            If the patcher is already started. Starting twice would
            overwrite the saved original values with patched ones, so the
            originals could never be restored.
        """
        if self.tmpdir is not None:
            raise RuntimeError(
                "ETSConfigPatcher.start() called while already started")

        config = self.etsconfig

        # Read the originals before creating anything. If this fails,
        # nothing has been patched and no temporary directory has leaked.
        old_application_data = config._application_data
        old_application_home = config._application_home
        old_user_data = config._user_data

        tmpdir = self.tmpdir = tempfile.mkdtemp()

        self.old_application_data = old_application_data
        config._application_data = os.path.join(
            tmpdir, "application_data")

        self.old_application_home = old_application_home
        config._application_home = os.path.join(
            tmpdir, "application_home")

        self.old_user_data = old_user_data
        config._user_data = os.path.join(tmpdir, "user_home")

    def stop(self):
        """
        Restore the original ETSConfig directories and remove the temporary
        directory.

        Does nothing if the patcher is not currently started, so calling it
        more than once is safe.
        """
        if self.tmpdir is None:
            return

        config = self.etsconfig

        # Restore unconditionally. An original value of None is a real
        # value and must be put back, not skipped.
        config._user_data = self.old_user_data
        config._application_home = self.old_application_home
        config._application_data = self.old_application_data

        self.old_user_data = None
        self.old_application_home = None
        self.old_application_data = None

        # Mark the patch inactive before removing the directory. If rmtree
        # raises, a later stop() then won't "restore" the cleared old_*
        # values over the real ones.
        tmpdir, self.tmpdir = self.tmpdir, None
        shutil.rmtree(tmpdir)

    def __enter__(self):
        """Start patching and return the patcher."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Stop patching; exceptions from the ``with`` body propagate."""
        self.stop()
59