# -*- coding: utf-8 -*-
#    This file is part of Gtfslib-python.
#
#    Gtfslib-python is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    Gtfslib-python is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with gtfslib-python.  If not, see <http://www.gnu.org/licenses/>.
"""
@author: Laurent GRÉGOIRE <laurent.gregoire@mecatran.com>
"""
import hashlib
import logging
import os
import random
import traceback
import unittest

import requests
from sqlalchemy.orm import clear_mappers

from gtfslib.dao import Dao

# By default this test is disabled, as it takes ages to run.
ENABLE = False
# Skip importing GTFS files that have already been downloaded.
# Handy for re-launching the test without having to redo everything.
SKIP_EXISTING = True
# Limit loading to small GTFS files only. Max size in bytes.
MAX_GTFS_SIZE = 2 * 1024 * 1024
# Local cache directory
DIR = "all-gtfs.cache"
# List of IDs to load. If None, download the whole list
# and process it in random order.
IDS_TO_LOAD = None

# The following feeds were known to have fancy formats that used to break the importer:
# Contains a UTF-8 BOM:
# IDS_TO_LOAD = [ 'janesville-transit-system' ]
# Contains headers with spaces:
# IDS_TO_LOAD = [ 'rseau-stan' ]
# Breaks on non-unique stop times:
# IDS_TO_LOAD = [ 'biaostocka-komunikacja-miejska' ]
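
# Small helper (a convenience added here, not a gtfslib API): compute the MD5
# checksum of a local file in fixed-size chunks, so that large downloads are
# never read into memory at once.
def compute_file_md5(filename, chunk_size=65536):
    md5 = hashlib.md5()
    with open(filename, 'rb') as f:
        # iter() with a b'' sentinel yields chunks until EOF.
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()
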
class TestAllGtfs(unittest.TestCase):

    def setUp(self):
        unittest.TestCase.setUp(self)
        clear_mappers()

    # Download all GTFS feeds from the GTFS data-exchange web-site
    # and load them into a DAO.
    def test_all_gtfs(self):

        if not ENABLE:
            print("This test is disabled as it is very time-consuming.")
            print("If you want to enable it, see the ENABLE flag at the top of this file.")
            return

        # Create the cache directory if it is not already there.
        if not os.path.isdir(DIR):
            os.mkdir(DIR)

        # Create a DAO, re-using any existing database.
        logging.basicConfig(level=logging.INFO)
        dao = Dao("%s/all_gtfs.sqlite" % (DIR))

        deids = IDS_TO_LOAD
        if deids is None:
            print("Downloading meta-info for all agencies...")
            resource_url = "http://www.gtfs-data-exchange.com/api/agencies?format=json"
            response = requests.get(resource_url).json()
            if response.get('status_code') != 200:
                raise IOError("Error %s (%s)" % (response.get('status_code'), response.get('status_txt')))
            deids = []
            for entry in response.get('data'):
                deid = entry.get('dataexchange_id')
                deids.append(deid)
            # Randomize the list, otherwise we would always load ABCBus first, and so on.
            random.shuffle(deids)

        for deid in deids:
            try:
                local_filename = "%s/%s.gtfs.zip" % (DIR, deid)
                if os.path.exists(local_filename) and SKIP_EXISTING:
                    print("Skipping [%s], GTFS already present." % (deid))
                    continue

                print("Downloading meta-info for ID [%s]" % (deid))
                resource_url = "http://www.gtfs-data-exchange.com/api/agency?agency=%s&format=json" % deid
                response = requests.get(resource_url).json()
                status_code = response.get('status_code')
                if status_code != 200:
                    raise IOError("Error %d (%s)" % (status_code, response.get('status_txt')))
                data = response.get('data')
                agency_data = data.get('agency')
                agency_name = agency_data.get('name')
                agency_area = agency_data.get('area')
                agency_country = agency_data.get('country')

                print("Processing [%s] %s (%s / %s)" % (deid, agency_name, agency_country, agency_area))
                # Pick the most recently added datafile.
                date_max = 0.0
                file_url = None
                file_size = 0
                file_md5 = None
                for datafile in data.get('datafiles'):
                    date_added = datafile.get('date_added')
                    if date_added > date_max:
                        date_max = date_added
                        file_url = datafile.get('file_url')
                        file_size = datafile.get('size')
                        file_md5 = datafile.get('md5sum')
                if file_url is None:
                    print("No datafile available, skipping.")
                    continue

                if file_size > MAX_GTFS_SIZE:
                    print("GTFS too large (%d bytes > max %d), skipping." % (file_size, MAX_GTFS_SIZE))
                    continue

                # Check if the file is already present with a matching checksum,
                # and skip the download if so.
                try:
                    existing_md5 = compute_file_md5(local_filename)
                except IOError:
                    existing_md5 = None
                if existing_md5 == file_md5:
                    print("Using existing file '%s': MD5 checksum matches." % (local_filename))
                else:
                    print("Downloading file '%s' to '%s' (%d bytes)" % (file_url, local_filename, file_size))
                    cnx = requests.get(file_url, stream=True)
                    with open(local_filename, 'wb') as local_file:
                        for block in cnx.iter_content(1024):
                            local_file.write(block)
                    cnx.close()
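                    # Optional sanity check (an added sketch, not in the original
                    # test): warn if the downloaded file does not match the md5sum
                    # published in the meta-info, so corrupted downloads are easier to spot.
                    if file_md5 is not None and compute_file_md5(local_filename) != file_md5:
                        print("Warning: MD5 checksum mismatch for '%s'." % (local_filename))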

                feed = dao.feed(deid)
                if feed is not None:
                    print("Removing existing data for feed [%s]" % (deid))
                    dao.delete_feed(deid)
                print("Importing into DAO as ID [%s]" % (deid))
                try:
                    dao.load_gtfs(local_filename, feed_id=deid)
                except Exception:
                    error_filename = "%s/%s.error" % (DIR, deid)
                    print("Import of [%s]: FAILED. Logging error to '%s'" % (deid, error_filename))
                    # traceback.format_exc() returns a str, so open in text mode.
                    with open(error_filename, 'w') as errfile:
                        errfile.write(traceback.format_exc())
                    raise
                print("Import of [%s]: OK." % (deid))

            except Exception as error:
                logging.exception(error)
                continue

if __name__ == '__main__':
    unittest.main()