1# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
2# vi: set ft=python sts=4 ts=4 sw=4 et:
3### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
4#
5#   See COPYING file distributed along with the NiBabel package for the
6#   copyright and license terms.
7#
8### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
9# Copyright (C) 2011 Christian Haselgrove
10""" DICOM filesystem tools
11"""
12
13
14import os
15from os.path import join as pjoin
16import tempfile
17import getpass
18import logging
19import warnings
20import sqlite3
21
22import numpy
23
24from io import BytesIO
25
26from .nifti1 import Nifti1Header
27
28from .pydicom_compat import pydicom, read_file
29
30logger = logging.getLogger('nibabel.dft')
31
32
class DFTError(Exception):
    """Base class for all DFT (DICOM filesystem tools) exceptions."""
35
36
class CachingError(DFTError):
    """Raised when scanning or updating the DICOM cache fails."""
39
40
class VolumeError(DFTError):
    """Raised when a series has parameters unsupported for conversion
    (e.g. unexpected bits allocated / bits stored, too few slices)."""
43
44
45class InstanceStackError(DFTError):
46
47    "bad series of instance numbers"
48
49    def __init__(self, series, i, si):
50        self.series = series
51        self.i = i
52        self.si = si
53        return
54
55    def __str__(self):
56        fmt = 'expecting instance number %d, got %d'
57        return fmt % (self.i + 1, self.si.instance_number)
58
59
60class _Study(object):
61
62    def __init__(self, d):
63        self.uid = d['uid']
64        self.date = d['date']
65        self.time = d['time']
66        self.comments = d['comments']
67        self.patient_name = d['patient_name']
68        self.patient_id = d['patient_id']
69        self.patient_birth_date = d['patient_birth_date']
70        self.patient_sex = d['patient_sex']
71        self.series = None
72        return
73
74    def __getattribute__(self, name):
75        val = object.__getattribute__(self, name)
76        if name == 'series' and val is None:
77            val = []
78            with _db_nochange() as c:
79                c.execute("SELECT * FROM series WHERE study = ?", (self.uid, ))
80                cols = [el[0] for el in c.description]
81                for row in c:
82                    d = dict(zip(cols, row))
83                    val.append(_Series(d))
84            self.series = val
85        return val
86
87    def patient_name_or_uid(self):
88        if self.patient_name == '':
89            return self.uid
90        return self.patient_name
91
92
93class _Series(object):
94
95    def __init__(self, d):
96        self.uid = d['uid']
97        self.study = d['study']
98        self.number = d['number']
99        self.description = d['description']
100        self.rows = d['rows']
101        self.columns = d['columns']
102        self.bits_allocated = d['bits_allocated']
103        self.bits_stored = d['bits_stored']
104        self.storage_instances = None
105        return
106
107    def __getattribute__(self, name):
108        val = object.__getattribute__(self, name)
109        if name == 'storage_instances' and val is None:
110            val = []
111            with _db_nochange() as c:
112                query = """SELECT *
113                             FROM storage_instance
114                            WHERE series = ?
115                            ORDER BY instance_number"""
116                c.execute(query, (self.uid, ))
117                cols = [el[0] for el in c.description]
118                for row in c:
119                    d = dict(zip(cols, row))
120                    val.append(_StorageInstance(d))
121            self.storage_instances = val
122        return val
123
124    def as_png(self, index=None, scale_to_slice=True):
125        import PIL.Image
126        # For compatibility with older versions of PIL that did not
127        # have `frombytes`:
128        if hasattr(PIL.Image, 'frombytes'):
129            frombytes = PIL.Image.frombytes
130        else:
131            frombytes = PIL.Image.fromstring
132
133        if index is None:
134            index = len(self.storage_instances) // 2
135        d = self.storage_instances[index].dicom()
136        data = d.pixel_array.copy()
137        if self.bits_allocated != 16:
138            raise VolumeError('unsupported bits allocated')
139        if self.bits_stored != 12:
140            raise VolumeError('unsupported bits stored')
141        data = data / 16
142        if scale_to_slice:
143            min = data.min()
144            max = data.max()
145            data = data * 255 / (max - min)
146        data = data.astype(numpy.uint8)
147        im = frombytes('L', (self.rows, self.columns), data.tobytes())
148
149        s = BytesIO()
150        im.save(s, 'PNG')
151        return s.getvalue()
152
153    def png_size(self, index=None, scale_to_slice=True):
154        return len(self.as_png(index=index, scale_to_slice=scale_to_slice))
155
156    def as_nifti(self):
157        if len(self.storage_instances) < 2:
158            raise VolumeError('too few slices')
159        d = self.storage_instances[0].dicom()
160        if self.bits_allocated != 16:
161            raise VolumeError('unsupported bits allocated')
162        if self.bits_stored != 12:
163            raise VolumeError('unsupported bits stored')
164        data = numpy.ndarray((len(self.storage_instances), self.rows,
165                              self.columns), dtype=numpy.int16)
166        for (i, si) in enumerate(self.storage_instances):
167            if i + 1 != si.instance_number:
168                raise InstanceStackError(self, i, si)
169            logger.info('reading %d/%d' % (i + 1, len(self.storage_instances)))
170            d = self.storage_instances[i].dicom()
171            data[i, :, :] = d.pixel_array
172
173        d1 = self.storage_instances[0].dicom()
174        dn = self.storage_instances[-1].dicom()
175
176        pdi = d1.PixelSpacing[0]
177        pdj = d1.PixelSpacing[0]
178        pdk = d1.SpacingBetweenSlices
179
180        cosi = d1.ImageOrientationPatient[0:3]
181        cosi[0] = -1 * cosi[0]
182        cosi[1] = -1 * cosi[1]
183        cosj = d1.ImageOrientationPatient[3:6]
184        cosj[0] = -1 * cosj[0]
185        cosj[1] = -1 * cosj[1]
186
187        pos_1 = numpy.array(d1.ImagePositionPatient)
188        pos_1[0] = -1 * pos_1[0]
189        pos_1[1] = -1 * pos_1[1]
190        pos_n = numpy.array(dn.ImagePositionPatient)
191        pos_n[0] = -1 * pos_n[0]
192        pos_n[1] = -1 * pos_n[1]
193        cosk = pos_n - pos_1
194        cosk = cosk / numpy.linalg.norm(cosk)
195
196        m = ((pdi * cosi[0], pdj * cosj[0], pdk * cosk[0], pos_1[0]),
197             (pdi * cosi[1], pdj * cosj[1], pdk * cosk[1], pos_1[1]),
198             (pdi * cosi[2], pdj * cosj[2], pdk * cosk[2], pos_1[2]),
199             (0, 0, 0, 1))
200
201        # Values are python Decimals in pydicom 0.9.7
202        m = numpy.array(m, dtype=float)
203
204        hdr = Nifti1Header(endianness='<')
205        hdr.set_intent(0)
206        hdr.set_qform(m, 1)
207        hdr.set_xyzt_units(2, 8)
208        hdr.set_data_dtype(numpy.int16)
209        hdr.set_data_shape((self.columns, self.rows,
210                            len(self.storage_instances)))
211
212        s = BytesIO()
213        hdr.write_to(s)
214
215        return s.getvalue() + data.tobytes()
216
217    def nifti_size(self):
218        return 352 + 2 * len(self.storage_instances) * self.columns * self.rows
219
220
221class _StorageInstance(object):
222
223    def __init__(self, d):
224        self.uid = d['uid']
225        self.instance_number = d['instance_number']
226        self.series = d['series']
227        self.files = None
228        return
229
230    def __getattribute__(self, name):
231        val = object.__getattribute__(self, name)
232        if name == 'files' and val is None:
233            with _db_nochange() as c:
234                query = """SELECT directory, name
235                             FROM file
236                            WHERE storage_instance = ?
237                            ORDER BY directory, name"""
238                c.execute(query, (self.uid, ))
239                val = ['%s/%s' % tuple(row) for row in c]
240            self.files = val
241        return val
242
243    def dicom(self):
244        return read_file(self.files[0])
245
246
class _db_nochange:
    """Context manager yielding a read-only cursor on the module DB.

    On exit any (unexpected) changes are rolled back.  The cursor is
    always closed -- the previous version leaked it when the body
    raised, since it only closed on the clean-exit path.
    """

    def __enter__(self):
        self.c = DB.cursor()
        return self.c

    def __exit__(self, type, value, traceback):
        # Close the cursor whether or not the body raised, then discard
        # any changes made through it.
        try:
            self.c.close()
        finally:
            DB.rollback()
        return
259
260
class _db_change:
    """Context manager yielding a cursor whose work is committed.

    On a clean exit the transaction is committed; if the body raises it
    is rolled back.  The cursor is closed in both cases -- the previous
    version leaked it on the exception path.
    """

    def __enter__(self):
        self.c = DB.cursor()
        return self.c

    def __exit__(self, type, value, traceback):
        try:
            self.c.close()
        finally:
            if type is None:
                DB.commit()
            else:
                DB.rollback()
        return
275
276
277def _get_subdirs(base_dir, files_dict=None, followlinks=False):
278    dirs = []
279    for (dirpath, dirnames, filenames) in os.walk(base_dir, followlinks=followlinks):
280        abs_dir = os.path.realpath(dirpath)
281        if abs_dir in dirs:
282            raise CachingError(f'link cycle detected under {base_dir}')
283        dirs.append(abs_dir)
284        if files_dict is not None:
285            files_dict[abs_dir] = filenames
286    return dirs
287
288
def update_cache(base_dir, followlinks=False):
    """Scan ``base_dir`` and bring the cache database up to date.

    Directories whose mtime has not changed since the last scan are
    skipped; new or modified directories are re-read file by file.
    """
    mtimes = {}
    files_by_dir = {}
    dirs = _get_subdirs(base_dir, files_by_dir, followlinks)
    for d in dirs:
        # Stat each directory exactly once (the previous version called
        # os.stat twice and discarded the first result).
        mtimes[d] = os.stat(d).st_mtime
    # Snapshot the current database state in one read-only transaction.
    with _db_nochange() as c:
        c.execute("SELECT path, mtime FROM directory")
        db_mtimes = dict(c)
        c.execute("SELECT uid FROM study")
        studies = [row[0] for row in c]
        c.execute("SELECT uid FROM series")
        series = [row[0] for row in c]
        c.execute("SELECT uid FROM storage_instance")
        storage_instances = [row[0] for row in c]
    with _db_change() as c:
        for dirname in sorted(mtimes.keys()):
            # Unchanged directory: nothing to do.
            if dirname in db_mtimes and mtimes[dirname] <= db_mtimes[dirname]:
                continue
            logger.debug(f'updating {dirname}')
            _update_dir(c, dirname, files_by_dir[dirname], studies, series,
                        storage_instances)
            if dirname in db_mtimes:
                query = "UPDATE directory SET mtime = ? WHERE path = ?"
                c.execute(query, (mtimes[dirname], dirname))
            else:
                query = "INSERT INTO directory (path, mtime) VALUES (?, ?)"
                c.execute(query, (dirname, mtimes[dirname]))
    return
319
320
def get_studies(base_dir=None, followlinks=False):
    """Return _Study objects from the cache.

    With ``base_dir`` given, the cache is refreshed first and only
    studies having files under that tree are returned; with no
    ``base_dir`` every cached study is returned.
    """
    if base_dir is None:
        # No directory restriction: return every study in the cache.
        with _db_nochange() as c:
            c.execute("SELECT * FROM study")
            cols = [el[0] for el in c.description]
            return [_Study(dict(zip(cols, row))) for row in c]
    update_cache(base_dir, followlinks)
    query = """SELECT study
                 FROM series
                WHERE uid IN (SELECT series
                                FROM storage_instance
                               WHERE uid IN (SELECT storage_instance
                                               FROM file
                                              WHERE directory = ?))"""
    studies = []
    with _db_nochange() as c:
        # Collect distinct study UIDs reachable from base_dir, keeping
        # first-seen order (dict keys are insertion-ordered).
        study_uids = {}
        for subdir in _get_subdirs(base_dir, followlinks=followlinks):
            c.execute(query, (subdir, ))
            for row in c:
                study_uids[row[0]] = None
        for uid in study_uids:
            c.execute("SELECT * FROM study WHERE uid = ?", (uid, ))
            cols = [el[0] for el in c.description]
            studies.append(_Study(dict(zip(cols, c.fetchone()))))
    return studies
353
354
def _update_dir(c, dir, files, studies, series, storage_instances):
    """Synchronize the ``file`` table with one directory's contents.

    Rows for vanished files are deleted; new or modified files are
    (re)parsed via _update_file() and inserted or updated.
    """
    logger.debug(f'Updating directory {dir}')
    c.execute("SELECT name, mtime FROM file WHERE directory = ?", (dir, ))
    db_mtimes = dict(c)
    # Drop database rows for files no longer present on disk.
    for fname in db_mtimes:
        if fname not in files:
            logger.debug(f'    remove {fname}')
            c.execute("DELETE FROM file WHERE directory = ? AND name = ?",
                      (dir, fname))
    # Insert or refresh rows for new or changed files.
    for fname in files:
        mtime = os.lstat(f'{dir}/{fname}').st_mtime
        known = fname in db_mtimes
        if known and mtime <= db_mtimes[fname]:
            logger.debug(f'    okay {fname}')
        else:
            logger.debug(f'    update {fname}')
            si_uid = _update_file(c, dir, fname, studies, series,
                                  storage_instances)
            if known:
                query = """UPDATE file
                              SET mtime = ?, storage_instance = ?
                            WHERE directory = ? AND name = ?"""
                c.execute(query, (mtime, si_uid, dir, fname))
            else:
                query = """INSERT INTO file (directory,
                                             name,
                                             mtime,
                                             storage_instance)
                           VALUES (?, ?, ?, ?)"""
                c.execute(query, (dir, fname, mtime, si_uid))
385
386
def _update_file(c, path, fname, studies, series, storage_instances):
    """Parse one file as DICOM and record its study/series/instance.

    The studies/series/storage_instances lists track UIDs already in
    the database (and are appended to) so rows are only inserted once.

    Returns the SOP instance UID, or None when the file is not DICOM
    or lacks a required attribute.
    """
    try:
        ds = read_file(f'{path}/{fname}')
    except pydicom.filereader.InvalidDicomError:
        logger.debug('        not a DICOM file')
        return None
    # StudyComments is optional; fall back to an empty string.
    study_comments = getattr(ds, 'StudyComments', '')
    try:
        logger.debug(f'        storage instance {ds.SOPInstanceUID}')
        if str(ds.StudyInstanceUID) not in studies:
            query = """INSERT INTO study (uid,
                                          date,
                                          time,
                                          comments,
                                          patient_name,
                                          patient_id,
                                          patient_birth_date,
                                          patient_sex)
                       VALUES (?, ?, ?, ?, ?, ?, ?, ?)"""
            c.execute(query, (str(ds.StudyInstanceUID),
                              ds.StudyDate,
                              ds.StudyTime,
                              study_comments,
                              str(ds.PatientName),
                              ds.PatientID,
                              ds.PatientBirthDate,
                              ds.PatientSex))
            studies.append(str(ds.StudyInstanceUID))
        if str(ds.SeriesInstanceUID) not in series:
            query = """INSERT INTO series (uid,
                                           study,
                                           number,
                                           description,
                                           rows,
                                           columns,
                                           bits_allocated,
                                           bits_stored)
                       VALUES (?, ?, ?, ?, ?, ?, ?, ?)"""
            c.execute(query, (str(ds.SeriesInstanceUID),
                              str(ds.StudyInstanceUID),
                              ds.SeriesNumber,
                              ds.SeriesDescription,
                              ds.Rows,
                              ds.Columns,
                              ds.BitsAllocated,
                              ds.BitsStored))
            series.append(str(ds.SeriesInstanceUID))
        if str(ds.SOPInstanceUID) not in storage_instances:
            query = """INSERT INTO storage_instance (uid, instance_number, series)
                       VALUES (?, ?, ?)"""
            c.execute(query, (str(ds.SOPInstanceUID), ds.InstanceNumber,
                              str(ds.SeriesInstanceUID)))
            storage_instances.append(str(ds.SOPInstanceUID))
    except AttributeError as exc:
        # A required DICOM attribute was missing; skip this file.
        logger.debug(f'        {exc}')
        return None
    return str(ds.SOPInstanceUID)
450
451
def clear_cache():
    """Delete every row from every cache table.

    Tables are cleared child-first (file before directory and
    storage_instance, storage_instance before series, series before
    study) so foreign-key references are never dangling mid-way.
    """
    tables = ('file', 'directory', 'storage_instance', 'series', 'study')
    with _db_change() as c:
        for table in tables:
            c.execute('DELETE FROM ' + table)
    return
460
461
# DDL for the cache schema: a study has many series, a series has many
# storage instances, and each scanned file on disk maps to at most one
# storage instance.
CREATE_QUERIES = (
    """CREATE TABLE study (uid TEXT NOT NULL PRIMARY KEY,
                           date TEXT NOT NULL,
                           time TEXT NOT NULL,
                           comments TEXT NOT NULL,
                           patient_name TEXT NOT NULL,
                           patient_id TEXT NOT NULL,
                           patient_birth_date TEXT NOT NULL,
                           patient_sex TEXT NOT NULL)""",
    """CREATE TABLE series (uid TEXT NOT NULL PRIMARY KEY,
                            study TEXT NOT NULL REFERENCES study,
                            number TEXT NOT NULL,
                            description TEXT NOT NULL,
                            rows INTEGER NOT NULL,
                            columns INTEGER NOT NULL,
                            bits_allocated INTEGER NOT NULL,
                            bits_stored INTEGER NOT NULL)""",
    """CREATE TABLE storage_instance (uid TEXT NOT NULL PRIMARY KEY,
                                      instance_number INTEGER NOT NULL,
                                      series TEXT NOT NULL references series)""",
    """CREATE TABLE directory (path TEXT NOT NULL PRIMARY KEY,
                               mtime INTEGER NOT NULL)""",
    """CREATE TABLE file (directory TEXT NOT NULL REFERENCES directory,
                          name TEXT NOT NULL,
                          mtime INTEGER NOT NULL,
                          storage_instance TEXT DEFAULT NULL REFERENCES storage_instance,
                          PRIMARY KEY (directory, name))""")
# Per-user cache database file in the system temporary directory.
DB_FNAME = pjoin(tempfile.gettempdir(), f'dft.{getpass.getuser()}.sqlite')
# Module-level sqlite3 connection; opened by _init_db() at import time
# (skipped on Windows -- see the bottom of this module).
DB = None
491
492
def _init_db(verbose=True):
    """Open the cache database, creating the schema if none exists."""
    global DB
    if verbose:
        logger.info('db filename: ' + DB_FNAME)
    # check_same_thread=False allows use from multiple threads.
    DB = sqlite3.connect(DB_FNAME, check_same_thread=False)
    with _db_change() as c:
        c.execute("SELECT COUNT(*) FROM sqlite_master WHERE type = 'table'")
        (n_tables, ) = c.fetchone()
        if n_tables == 0:
            logger.debug('create')
            for q in CREATE_QUERIES:
                c.execute(q)
505
506
# Module initialization: the warning text states dft depends on FUSE,
# which is unavailable on Windows, so the database is only opened on
# other platforms.
if os.name == 'nt':
    warnings.warn('dft needs FUSE which is not available for windows')
else:
    _init_db()
# eof
512