import logging
import os
import random
import shutil
import string
import sys
import tempfile
import time
from datetime import timedelta, datetime, date
from elasticsearch import Elasticsearch
from elasticsearch.exceptions import ConnectionError
from subprocess import Popen, PIPE
from curator import get_version

from . import testvars
from unittest import SkipTest, TestCase
from mock import Mock

client = None

DATEMAP = {
    'months': '%Y.%m',
    'weeks': '%Y.%W',
    'days': '%Y.%m.%d',
    'hours': '%Y.%m.%d.%H',
}
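
# Example (illustrative): datetime(2019, 3, 5, 4).strftime(DATEMAP['hours'])
# yields '2019.03.05.04', the timestring pattern used for hourly indices.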

host, port = os.environ.get('TEST_ES_SERVER', 'localhost:9200').split(':')
port = int(port) if port else 9200

def random_directory():
    dirname = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(8))
    # mkdtemp creates the directory itself, so no existence check is needed.
    return tempfile.mkdtemp(suffix=dirname)
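
# Example (illustrative): random_directory() returns a path such as
# /tmp/tmpa1b2c3d4K7WQ2XZ9, where the final eight characters are the random
# suffix appended to the mkdtemp-generated name.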

def get_client():
    global client
    if client is not None:
        return client

    client = Elasticsearch([os.environ.get('TEST_ES_SERVER', 'localhost:9200')], timeout=300)

    # Wait for yellow status, retrying for roughly ten seconds.
    for _ in range(100):
        time.sleep(.1)
        try:
            # pylint: disable=E1123
            client.cluster.health(wait_for_status='yellow')
            return client
        except ConnectionError:
            continue
    # The loop never produced a healthy client, so the cluster is
    # unreachable: skip rather than fail the whole suite.
    raise SkipTest("Elasticsearch failed to start.")

def setup():
    get_client()

class Args(dict):
    # A dict whose keys are also readable as attributes; missing keys
    # resolve to None instead of raising AttributeError.
    def __getattr__(self, att_name):
        return self.get(att_name, None)

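# Example (illustrative): Args({'host': 'localhost'}).host evaluates to
# 'localhost', while Args({}).port evaluates to None, so tests can probe
# optional arguments without guarding for KeyError.
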
class CuratorTestCase(TestCase):
    def setUp(self):
        super(CuratorTestCase, self).setUp()
        self.logger = logging.getLogger('CuratorTestCase.setUp')
        self.client = get_client()

        args = {}
        args['host'], args['port'] = host, port
        args['time_unit'] = 'days'
        args['prefix'] = 'logstash-'
        self.args = args
        # random_directory() creates a pseudo-random temporary directory on
        # the machine which runs the unit tests, but NOT on the machine where
        # elasticsearch is running. This means tests may fail if run against
        # remote instances unless you explicitly set `self.args['location']`
        # to a proper spot on the target machine.
        self.args['location'] = random_directory()
        self.args['configdir'] = random_directory()
        self.args['configfile'] = os.path.join(self.args['configdir'], 'curator.yml')
        self.args['actionfile'] = os.path.join(self.args['configdir'], 'actions.yml')
        self.args['repository'] = 'TEST_REPOSITORY'
        self.logger.debug('setUp completed...')

    def tearDown(self):
        self.logger = logging.getLogger('CuratorTestCase.tearDown')
        self.logger.debug('tearDown initiated...')
        self.delete_repositories()
        self.client.indices.delete(index='*')
        # pylint: disable=E1123
        self.client.indices.delete_template(name='*', ignore=404)
        for path_arg in ['location', 'configdir']:
            if os.path.exists(self.args[path_arg]):
                shutil.rmtree(self.args[path_arg])

    def parse_args(self):
        return Args(self.args)

    def create_indices(self, count, unit=None, ilm_policy=None):
        now = datetime.utcnow()
        unit = unit if unit else self.args['time_unit']
        fmt = DATEMAP[unit]
        if unit != 'months':
            step = timedelta(**{unit: 1})
            for _ in range(count):
                self.create_index(self.args['prefix'] + now.strftime(fmt), wait_for_yellow=False, ilm_policy=ilm_policy)
                now -= step
        else: # months
            now = date.today()
            d = date(now.year, now.month, 1)
            self.create_index(self.args['prefix'] + now.strftime(fmt), wait_for_yellow=False, ilm_policy=ilm_policy)

            # timedelta cannot step by months, so walk backward one calendar
            # month at a time from the first of the current month.
            for _ in range(1, count):
                if d.month == 1:
                    d = date(d.year-1, 12, 1)
                else:
                    d = date(d.year, d.month-1, 1)
                self.create_index(self.args['prefix'] + datetime(d.year, d.month, 1).strftime(fmt), wait_for_yellow=False, ilm_policy=ilm_policy)
        # pylint: disable=E1123
        self.client.cluster.health(wait_for_status='yellow')
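    # Example (illustrative): self.create_indices(3) with the default 'days'
    # unit creates indices for today and the two preceding days, e.g.
    # logstash-2019.03.05, logstash-2019.03.04, and logstash-2019.03.03.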

    def wfy(self):
        # pylint: disable=E1123
        self.client.cluster.health(wait_for_status='yellow')

    def create_index(self, name, shards=1, wait_for_yellow=True, ilm_policy=None, wait_for_active_shards=1):
        request_body = {'settings': {'number_of_shards': shards, 'number_of_replicas': 0}}
        if ilm_policy is not None:
            request_body['settings']['index'] = {'lifecycle': {'name': ilm_policy}}
        self.client.indices.create(index=name, body=request_body, wait_for_active_shards=wait_for_active_shards)
        if wait_for_yellow:
            self.wfy()

    def add_docs(self, idx):
        # The default document type was renamed from 'doc' to '_doc' in
        # Elasticsearch 7, so check the cluster version once up front.
        doc_type = '_doc' if get_version(self.client) >= (7, 0, 0) else 'doc'
        for i in ["1", "2", "3"]:
            self.client.create(
                index=idx, doc_type=doc_type, id=i, body={"doc" + i: 'TEST DOCUMENT'})
            # This should force each doc to be in its own segment.
            # pylint: disable=E1123
            self.client.indices.flush(index=idx, force=True)
            self.client.indices.refresh(index=idx)

    def create_snapshot(self, name, csv_indices):
        body = {
            "indices": csv_indices,
            "ignore_unavailable": False,
            "include_global_state": True,
            "partial": False,
        }
        self.create_repository()
        # pylint: disable=E1123
        self.client.snapshot.create(
            repository=self.args['repository'], snapshot=name, body=body,
            wait_for_completion=True
        )
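    # Example (illustrative): self.create_snapshot('snap1',
    # 'logstash-2019.03.04,logstash-2019.03.05') snapshots both indices
    # into the TEST_REPOSITORY filesystem repository.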

    def create_repository(self):
        body = {'type': 'fs', 'settings': {'location': self.args['location']}}
        self.client.snapshot.create_repository(repository=self.args['repository'], body=body)

    def delete_repositories(self):
        result = self.client.snapshot.get_repository(repository='_all')
        for repo in result:
            self.client.snapshot.delete_repository(repository=repo)

    def close_index(self, name):
        self.client.indices.close(index=name)

    def write_config(self, fname, data):
        with open(fname, 'w') as f:
            f.write(data)

    def get_runner_args(self):
        self.write_config(self.args['configfile'], testvars.client_config.format(host, port))
        runner = os.path.join(os.getcwd(), 'run_singleton.py')
        return [sys.executable, runner]

    def run_subprocess(self, args, logname='subprocess'):
        logger = logging.getLogger(logname)
        p = Popen(args, stderr=PIPE, stdout=PIPE)
        stdout, stderr = p.communicate()
        logger.debug('STDOUT = {0}'.format(stdout.decode('utf-8')))
        logger.debug('STDERR = {0}'.format(stderr.decode('utf-8')))
        return p.returncode
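
# Example (illustrative) of a concrete test built on this base class; the
# class name and assertions are hypothetical:
#
#   class TestIndexCreation(CuratorTestCase):
#       def test_creates_daily_indices(self):
#           self.create_indices(3)
#           self.assertEqual(3, len(self.client.indices.get(index='logstash-*')))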