1"""Interfaces for accessing metadata.
2
3We provide two implementations.
4 * The "classic" file system implementation, which uses a directory
5   structure of files.
6 * A hokey sqlite backed implementation, which basically simulates
7   the file system in an effort to work around poor file system performance
8   on OS X.
9"""
10
11import binascii
12import os
13import time
14
15from abc import abstractmethod
16from typing import List, Iterable, Any, Optional
17from typing_extensions import TYPE_CHECKING
18if TYPE_CHECKING:
19    # We avoid importing sqlite3 unless we are using it so we can mostly work
20    # on semi-broken pythons that are missing it.
21    import sqlite3
22
23
class MetadataStore:
    """Generic interface for metadata storage.

    Entries are addressed by a relative, path-like name and carry a string
    payload plus a modification time.
    """

    # NOTE(review): this class does not use ABCMeta, so the @abstractmethod
    # decorators below document intent but are not enforced when a subclass
    # is instantiated.

    @abstractmethod
    def getmtime(self, name: str) -> float:
        """Read the mtime of a metadata entry.

        Raises FileNotFoundError if the entry does not exist.
        """
        pass

    @abstractmethod
    def read(self, name: str) -> str:
        """Read the contents of a metadata entry.

        Raises FileNotFoundError if the entry does not exist.
        """
        pass

    @abstractmethod
    def write(self, name: str, data: str, mtime: Optional[float] = None) -> bool:
        """Write a metadata entry.

        If mtime is specified, set it as the mtime of the entry. Otherwise,
        the current time is used.

        Returns True if the entry is successfully written, False otherwise.
        """

    @abstractmethod
    def remove(self, name: str) -> None:
        """Delete a metadata entry.

        Whether removing a missing entry raises is implementation-specific.
        """
        pass

    @abstractmethod
    def commit(self) -> None:
        """If the backing store requires a commit, do it.

        But N.B. that this is not *guaranteed* to do anything, and
        there is no guarantee that changes are not made until it is
        called.
        """
        pass

    @abstractmethod
    def list_all(self) -> Iterable[str]:
        """Return the names of all entries currently in the store."""
        ...
70
71
def random_string() -> str:
    """Return 16 lowercase hex digits encoding 8 bytes from os.urandom()."""
    raw_bytes = os.urandom(8)
    return raw_bytes.hex()
74
75
class FilesystemMetadataStore(MetadataStore):
    """Metadata store that keeps each entry as a file under a cache directory.

    Entry names are paths relative to the cache directory. A cache directory
    starting with os.devnull disables the store. In that case reads, mtime
    lookups and removals raise FileNotFoundError, writes return False and
    list_all() yields nothing.
    """

    def __init__(self, cache_dir_prefix: str) -> None:
        # We check startswith instead of equality because the version
        # will have already been appended by the time the cache dir is
        # passed here.
        if cache_dir_prefix.startswith(os.devnull):
            self.cache_dir_prefix = None  # type: Optional[str]
        else:
            self.cache_dir_prefix = cache_dir_prefix

    def getmtime(self, name: str) -> float:
        """Return the entry's file mtime, truncated to whole seconds.

        Raises FileNotFoundError if the store is disabled or the entry does
        not exist.
        """
        if not self.cache_dir_prefix:
            raise FileNotFoundError()

        # NOTE(review): the int() truncation looks deliberate (callers
        # presumably compare whole-second mtimes) -- confirm before changing.
        return int(os.path.getmtime(os.path.join(self.cache_dir_prefix, name)))

    def read(self, name: str) -> str:
        """Return the text contents of entry `name`.

        Raises FileNotFoundError if the store is disabled or the entry does
        not exist.
        """
        assert os.path.normpath(name) != os.path.abspath(name), "Don't use absolute paths!"

        if not self.cache_dir_prefix:
            raise FileNotFoundError()

        with open(os.path.join(self.cache_dir_prefix, name), 'r') as f:
            return f.read()

    def write(self, name: str, data: str, mtime: Optional[float] = None) -> bool:
        """Atomically write `data` to entry `name`, optionally setting its mtime.

        The data goes to a uniquely named temporary file first, which is then
        moved over the target with os.replace(). That way readers never
        observe a partially written entry.

        Returns False if the store is disabled or any OS-level step fails
        (any leftover temporary file is removed); True otherwise.
        """
        assert os.path.normpath(name) != os.path.abspath(name), "Don't use absolute paths!"

        if not self.cache_dir_prefix:
            return False

        path = os.path.join(self.cache_dir_prefix, name)
        tmp_filename = path + '.' + random_string()
        try:
            os.makedirs(os.path.dirname(path), exist_ok=True)
            with open(tmp_filename, 'w') as f:
                f.write(data)
            os.replace(tmp_filename, path)
            if mtime is not None:
                os.utime(path, times=(mtime, mtime))
        except OSError:
            # Don't leave a stray temporary file behind. If os.replace()
            # already succeeded, the temp name no longer exists; ignore that.
            try:
                os.remove(tmp_filename)
            except OSError:
                pass
            return False
        return True

    def remove(self, name: str) -> None:
        """Delete entry `name`.

        Raises FileNotFoundError if the store is disabled, and propagates
        os.remove() errors (e.g. FileNotFoundError for a missing entry).
        """
        if not self.cache_dir_prefix:
            raise FileNotFoundError()

        os.remove(os.path.join(self.cache_dir_prefix, name))

    def commit(self) -> None:
        """No-op: writes go straight to the file system."""
        pass

    def list_all(self) -> Iterable[str]:
        """Yield the relative names of all files under the cache directory."""
        if not self.cache_dir_prefix:
            return

        for dirpath, _, files in os.walk(self.cache_dir_prefix):
            rel_dir = os.path.relpath(dirpath, self.cache_dir_prefix)
            for file in files:
                # relpath() returns '.' for the cache root itself. Yield
                # top-level entries as 'x' rather than './x', so the names
                # round-trip into stores keyed by the exact string.
                if rel_dir == os.curdir:
                    yield file
                else:
                    yield os.path.join(rel_dir, file)
138
139
# Schema for the SQLite metadata store: one row per entry name. The UNIQUE
# constraint on `path` already makes SQLite build an index for lookups by
# path, so no separate index is created here. Databases created by older
# versions may still contain a redundant `path_idx` index; it is harmless.
SCHEMA = '''
CREATE TABLE IF NOT EXISTS files (
    path TEXT UNIQUE NOT NULL,
    mtime REAL,
    data TEXT
);
'''
# SQL scripts that connect_db() runs, in order, after SCHEMA on every
# connection. It ignores sqlite3.OperationalError from them, so each script
# presumably needs to be safe to re-run against an already-migrated database.
# No migrations yet
MIGRATIONS = [
]  # type: List[str]
151
152
def connect_db(db_file: str) -> 'sqlite3.Connection':
    """Open (creating if necessary) the SQLite database at `db_file`.

    Ensures the tables in SCHEMA exist, then runs every script in MIGRATIONS.
    A sqlite3.OperationalError from an individual migration is ignored. If any
    other error occurs during setup, the connection is closed before the
    error propagates, so it is not leaked.
    """
    # Imported lazily so Pythons built without sqlite3 can still use the
    # file system store.
    import sqlite3.dbapi2

    db = sqlite3.dbapi2.connect(db_file)
    try:
        db.executescript(SCHEMA)
        for migr in MIGRATIONS:
            try:
                db.executescript(migr)
            except sqlite3.OperationalError:
                # Best effort: presumably the migration was already applied.
                pass
    except BaseException:
        db.close()
        raise
    return db
164
165
class SqliteMetadataStore(MetadataStore):
    """Metadata store that keeps every entry as a row in one SQLite database.

    The database lives at `cache.db` inside the cache directory. A cache
    directory starting with os.devnull disables the store (self.db is None).
    """

    def __init__(self, cache_dir_prefix: str) -> None:
        # The version is already appended to the cache dir by the time it
        # reaches us, hence startswith() rather than an equality test.
        disabled = cache_dir_prefix.startswith(os.devnull)
        if disabled:
            self.db = None
        else:
            os.makedirs(cache_dir_prefix, exist_ok=True)
            db_path = os.path.join(cache_dir_prefix, 'cache.db')
            self.db = connect_db(db_path)

    def _query(self, name: str, field: str) -> Any:
        """Fetch column `field` of the row for `name`.

        Raises FileNotFoundError (mirroring the file system store) when the
        store is disabled or no such row exists.
        """
        if self.db is None:
            raise FileNotFoundError()

        sql = 'SELECT {} FROM files WHERE path = ?'.format(field)
        rows = self.db.execute(sql, (name,)).fetchall()
        if len(rows) == 0:
            raise FileNotFoundError()
        assert len(rows) == 1
        first_row = rows[0]
        return first_row[0]

    def getmtime(self, name: str) -> float:
        """Return the stored mtime of entry `name`."""
        return self._query(name, 'mtime')

    def read(self, name: str) -> str:
        """Return the stored data of entry `name`."""
        return self._query(name, 'data')

    def write(self, name: str, data: str, mtime: Optional[float] = None) -> bool:
        """Insert or overwrite entry `name`; False if disabled or SQLite fails."""
        import sqlite3

        if self.db is None:
            return False
        try:
            stamp = time.time() if mtime is None else mtime
            self.db.execute('INSERT OR REPLACE INTO files(path, mtime, data) VALUES(?, ?, ?)',
                            (name, stamp, data))
        except sqlite3.OperationalError:
            return False
        else:
            return True

    def remove(self, name: str) -> None:
        """Delete the row for `name`; raises FileNotFoundError if disabled."""
        if self.db is None:
            raise FileNotFoundError()
        self.db.execute('DELETE FROM files WHERE path = ?', (name,))

    def commit(self) -> None:
        """Commit pending changes to the database, if the store is enabled."""
        if self.db is not None:
            self.db.commit()

    def list_all(self) -> Iterable[str]:
        """Yield every stored entry name (nothing if the store is disabled)."""
        if self.db is None:
            return
        for (path,) in self.db.execute('SELECT path FROM files'):
            yield path
224