from __future__ import absolute_import
from __future__ import unicode_literals

import datetime
import logging
import operator
import re
from functools import reduce

import enum
import six
from docker.errors import APIError
from docker.utils import version_lt

from . import parallel
from .config import ConfigurationError
from .config.config import V1
from .config.sort_services import get_container_name_from_network_mode
from .config.sort_services import get_service_name_from_network_mode
from .const import LABEL_ONE_OFF
from .const import LABEL_PROJECT
from .const import LABEL_SERVICE
from .container import Container
from .network import build_networks
from .network import get_networks
from .network import ProjectNetworks
from .service import BuildAction
from .service import ContainerNetworkMode
from .service import ContainerPidMode
from .service import ConvergenceStrategy
from .service import NetworkMode
from .service import parse_repository_tag
from .service import PidMode
from .service import Service
from .service import ServiceNetworkMode
from .service import ServicePidMode
from .utils import microseconds_from_time_nano
from .utils import truncate_string
from .volume import ProjectVolumes


log = logging.getLogger(__name__)


@enum.unique
class OneOffFilter(enum.Enum):
    """Filter for including, excluding, or selecting only one-off (run) containers."""
    include = 0
    exclude = 1
    only = 2

    @classmethod
    def update_labels(cls, value, labels):
        if value == cls.only:
            labels.append('{0}={1}'.format(LABEL_ONE_OFF, "True"))
        elif value == cls.exclude:
            labels.append('{0}={1}'.format(LABEL_ONE_OFF, "False"))
        elif value == cls.include:
            pass
        else:
            raise ValueError("Invalid value for one_off: {}".format(repr(value)))


class Project(object):
    """
    A collection of services.
    """
    def __init__(self, name, services, client, networks=None, volumes=None, config_version=None):
        self.name = name
        self.services = services
        self.client = client
        self.volumes = volumes or ProjectVolumes({})
        self.networks = networks or ProjectNetworks({}, False)
        self.config_version = config_version

    def labels(self, one_off=OneOffFilter.exclude, legacy=False):
        name = self.name
        if legacy:
            name = re.sub(r'[_-]', '', name)
        labels = ['{0}={1}'.format(LABEL_PROJECT, name)]

        OneOffFilter.update_labels(one_off, labels)
        return labels
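
    # Example (a sketch, assuming the label constants in .const keep their
    # conventional values, LABEL_PROJECT == 'com.docker.compose.project' and
    # LABEL_ONE_OFF == 'com.docker.compose.oneoff'), for a project named "myapp":
    #
    #   project.labels()
    #   # ['com.docker.compose.project=myapp', 'com.docker.compose.oneoff=False']
    #   project.labels(one_off=OneOffFilter.only)
    #   # ['com.docker.compose.project=myapp', 'com.docker.compose.oneoff=True']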

    @classmethod
    def from_config(cls, name, config_data, client, default_platform=None):
        """
        Construct a Project from a config.Config object.
        """
        use_networking = (config_data.version and config_data.version != V1)
        networks = build_networks(name, config_data, client)
        project_networks = ProjectNetworks.from_services(
            config_data.services,
            networks,
            use_networking)
        volumes = ProjectVolumes.from_config(name, config_data, client)
        project = cls(name, [], client, project_networks, volumes, config_data.version)

        for service_dict in config_data.services:
            service_dict = dict(service_dict)
            if use_networking:
                service_networks = get_networks(service_dict, networks)
            else:
                service_networks = {}

            service_dict.pop('networks', None)
            links = project.get_links(service_dict)
            network_mode = project.get_network_mode(
                service_dict, list(service_networks.keys())
            )
            pid_mode = project.get_pid_mode(service_dict)
            volumes_from = get_volumes_from(project, service_dict)

            if config_data.version != V1:
                service_dict['volumes'] = [
                    volumes.namespace_spec(volume_spec)
                    for volume_spec in service_dict.get('volumes', [])
                ]

            secrets = get_secrets(
                service_dict['name'],
                service_dict.pop('secrets', None) or [],
                config_data.secrets)

            project.services.append(
                Service(
                    service_dict.pop('name'),
                    client=client,
                    project=name,
                    use_networking=use_networking,
                    networks=service_networks,
                    links=links,
                    network_mode=network_mode,
                    volumes_from=volumes_from,
                    secrets=secrets,
                    pid_mode=pid_mode,
                    platform=service_dict.pop('platform', None),
                    default_platform=default_platform,
                    **service_dict)
            )

        return project
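
    # Usage sketch (not part of this module): callers typically load a
    # config.Config object and a Docker API client, then build the project from
    # them. The config-loading helpers named below are assumptions about the
    # surrounding code base, not defined in this file.
    #
    #   config_data = config.load(config.find('.', None, environment))
    #   project = Project.from_config('myapp', config_data, docker_client)
    #   project.up()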

    @property
    def service_names(self):
        return [service.name for service in self.services]

    def get_service(self, name):
        """
        Retrieve a service by name. Raises NoSuchService
        if the named service does not exist.
        """
        for service in self.services:
            if service.name == name:
                return service

        raise NoSuchService(name)

    def validate_service_names(self, service_names):
        """
        Validate that the given list of service names only contains valid
        services. Raises NoSuchService if one of the names is invalid.
        """
        valid_names = self.service_names
        for name in service_names:
            if name not in valid_names:
                raise NoSuchService(name)

    def get_services(self, service_names=None, include_deps=False):
        """
        Returns a list of this project's services filtered
        by the provided list of names, or all services if service_names is None
        or [].

        If include_deps is specified, returns a list including the dependencies for
        service_names, in order of dependency.

        Preserves the original order of self.services where possible,
        reordering as needed to resolve dependencies.

        Raises NoSuchService if any of the named services do not exist.
        """
        if service_names is None or len(service_names) == 0:
            service_names = self.service_names

        unsorted = [self.get_service(name) for name in service_names]
        services = [s for s in self.services if s in unsorted]

        if include_deps:
            services = reduce(self._inject_deps, services, [])

        uniques = []
        for service in services:
            if service not in uniques:
                uniques.append(service)

        return uniques
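
    # Example (a sketch, assuming a project whose "web" service depends on "db"
    # through depends_on/links):
    #
    #   project.get_services(['web'])                     # [<Service: web>]
    #   project.get_services(['web'], include_deps=True)  # [<Service: db>, <Service: web>]
    #   project.get_services()                            # every service, in file order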

    def get_services_without_duplicate(self, service_names=None, include_deps=False):
        services = self.get_services(service_names, include_deps)
        for service in services:
            service.remove_duplicate_containers()
        return services

    def get_links(self, service_dict):
        links = []
        if 'links' in service_dict:
            for link in service_dict.get('links', []):
                if ':' in link:
                    service_name, link_name = link.split(':', 1)
                else:
                    service_name, link_name = link, None
                try:
                    links.append((self.get_service(service_name), link_name))
                except NoSuchService:
                    raise ConfigurationError(
                        'Service "%s" has a link to service "%s" which does not '
                        'exist.' % (service_dict['name'], service_name))
            del service_dict['links']
        return links

    def get_network_mode(self, service_dict, networks):
        network_mode = service_dict.pop('network_mode', None)
        if not network_mode:
            if self.networks.use_networking:
                return NetworkMode(networks[0]) if networks else NetworkMode('none')
            return NetworkMode(None)

        service_name = get_service_name_from_network_mode(network_mode)
        if service_name:
            return ServiceNetworkMode(self.get_service(service_name))

        container_name = get_container_name_from_network_mode(network_mode)
        if container_name:
            try:
                return ContainerNetworkMode(Container.from_id(self.client, container_name))
            except APIError:
                raise ConfigurationError(
                    "Service '{name}' uses the network stack of container '{dep}' which "
                    "does not exist.".format(name=service_dict['name'], dep=container_name))

        return NetworkMode(network_mode)
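
    # How a service's `network_mode` value is resolved by the method above
    # (a sketch; the 'service:'/'container:' prefixes are parsed by the
    # imported sort_services helpers):
    #
    #   (unset)            -> the first attached network, 'none' if there are no
    #                         networks, or NetworkMode(None) when networking is off
    #   'service:web'      -> ServiceNetworkMode(<Service: web>)
    #   'container:abc123' -> ContainerNetworkMode(<Container: abc123>)
    #   'host' / 'bridge'  -> NetworkMode('host') / NetworkMode('bridge')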

    def get_pid_mode(self, service_dict):
        pid_mode = service_dict.pop('pid', None)
        if not pid_mode:
            return PidMode(None)

        service_name = get_service_name_from_network_mode(pid_mode)
        if service_name:
            return ServicePidMode(self.get_service(service_name))

        container_name = get_container_name_from_network_mode(pid_mode)
        if container_name:
            try:
                return ContainerPidMode(Container.from_id(self.client, container_name))
            except APIError:
                raise ConfigurationError(
                    "Service '{name}' uses the PID namespace of container '{dep}' which "
                    "does not exist.".format(name=service_dict['name'], dep=container_name)
                )

        return PidMode(pid_mode)

    def start(self, service_names=None, **options):
        containers = []

        def start_service(service):
            service_containers = service.start(quiet=True, **options)
            containers.extend(service_containers)

        services = self.get_services(service_names)

        def get_deps(service):
            return {
                (self.get_service(dep), config)
                for dep, config in service.get_dependency_configs().items()
            }

        parallel.parallel_execute(
            services,
            start_service,
            operator.attrgetter('name'),
            'Starting',
            get_deps,
            fail_check=lambda obj: not obj.containers(),
        )

        return containers

    def stop(self, service_names=None, one_off=OneOffFilter.exclude, **options):
        containers = self.containers(service_names, one_off=one_off)

        def get_deps(container):
            # Actually returns inverted dependencies: the containers whose
            # services depend on this container's service, so that dependents
            # are stopped before the services they depend on.
            return {(other, None) for other in containers
                    if container.service in
                    self.get_service(other.service).get_dependency_names()}

        parallel.parallel_execute(
            containers,
            self.build_container_operation_with_timeout_func('stop', options),
            operator.attrgetter('name'),
            'Stopping',
            get_deps,
        )

    def pause(self, service_names=None, **options):
        containers = self.containers(service_names)
        parallel.parallel_pause(reversed(containers), options)
        return containers

    def unpause(self, service_names=None, **options):
        containers = self.containers(service_names)
        parallel.parallel_unpause(containers, options)
        return containers

    def kill(self, service_names=None, **options):
        parallel.parallel_kill(self.containers(service_names), options)

    def remove_stopped(self, service_names=None, one_off=OneOffFilter.exclude, **options):
        parallel.parallel_remove(self.containers(
            service_names, stopped=True, one_off=one_off
        ), options)

    def down(
            self,
            remove_image_type,
            include_volumes,
            remove_orphans=False,
            timeout=None,
            ignore_orphans=False):
        self.stop(one_off=OneOffFilter.include, timeout=timeout)
        if not ignore_orphans:
            self.find_orphan_containers(remove_orphans)
        self.remove_stopped(v=include_volumes, one_off=OneOffFilter.include)

        self.networks.remove()

        if include_volumes:
            self.volumes.remove()

        self.remove_images(remove_image_type)

    def remove_images(self, remove_image_type):
        for service in self.get_services():
            service.remove_image(remove_image_type)

    def restart(self, service_names=None, **options):
        containers = self.containers(service_names, stopped=True)

        parallel.parallel_execute(
            containers,
            self.build_container_operation_with_timeout_func('restart', options),
            operator.attrgetter('name'),
            'Restarting',
        )
        return containers

    def build(self, service_names=None, no_cache=False, pull=False, force_rm=False, memory=None,
              build_args=None, gzip=False, parallel_build=False):

        services = []
        for service in self.get_services(service_names):
            if service.can_be_built():
                services.append(service)
            else:
                log.info('%s uses an image, skipping', service.name)

        def build_service(service):
            service.build(no_cache, pull, force_rm, memory, build_args, gzip)

        if parallel_build:
            _, errors = parallel.parallel_execute(
                services,
                build_service,
                operator.attrgetter('name'),
                'Building',
                limit=5,
            )
            if errors:
                combined_errors = '\n'.join([
                    e.decode('utf-8') if isinstance(e, six.binary_type) else e for e in errors.values()
                ])
                raise ProjectError(combined_errors)

        else:
            for service in services:
                build_service(service)

    def create(
        self,
        service_names=None,
        strategy=ConvergenceStrategy.changed,
        do_build=BuildAction.none,
    ):
        services = self.get_services_without_duplicate(service_names, include_deps=True)

        for svc in services:
            svc.ensure_image_exists(do_build=do_build)
        plans = self._get_convergence_plans(services, strategy)

        for service in services:
            service.execute_convergence_plan(
                plans[service.name],
                detached=True,
                start=False)

    def _legacy_event_processor(self, service_names):
        # Only for v1 files or when Compose is forced to use an older API version
        def build_container_event(event, container):
            time = datetime.datetime.fromtimestamp(event['time'])
            time = time.replace(
                microsecond=microseconds_from_time_nano(event['timeNano'])
            )
            return {
                'time': time,
                'type': 'container',
                'action': event['status'],
                'id': container.id,
                'service': container.service,
                'attributes': {
                    'name': container.name,
                    'image': event['from'],
                },
                'container': container,
            }

        service_names = set(service_names or self.service_names)
        for event in self.client.events(
            filters={'label': self.labels()},
            decode=True
        ):
            # This is a guard against some events broadcast by Swarm that
            # don't have a status field.
            # See https://github.com/docker/compose/issues/3316
            if 'status' not in event:
                continue

            try:
                # this can fail if the container has been removed or if the event
                # refers to an image
                container = Container.from_id(self.client, event['id'])
            except APIError:
                continue
            if container.service not in service_names:
                continue
            yield build_container_event(event, container)

    def events(self, service_names=None):
        if version_lt(self.client.api_version, '1.22'):
            # The richer events API (Actor/Attributes) was introduced in API
            # version 1.22; fall back to the legacy processor on older engines.
            return self._legacy_event_processor(service_names)

        def build_container_event(event):
            container_attrs = event['Actor']['Attributes']
            time = datetime.datetime.fromtimestamp(event['time'])
            time = time.replace(
                microsecond=microseconds_from_time_nano(event['timeNano'])
            )

            container = None
            try:
                container = Container.from_id(self.client, event['id'])
            except APIError:
                # Container may have been removed (e.g. if this is a destroy event)
                pass

            return {
                'time': time,
                'type': 'container',
                'action': event['status'],
                'id': event['Actor']['ID'],
                'service': container_attrs.get(LABEL_SERVICE),
                'attributes': {
                    k: v for k, v in container_attrs.items()
                    if not k.startswith('com.docker.compose.')
                },
                'container': container,
            }

        def yield_loop(service_names):
            for event in self.client.events(
                filters={'label': self.labels()},
                decode=True
            ):
                # TODO: support other event types
                if event.get('Type') != 'container':
                    continue

                try:
                    if event['Actor']['Attributes'][LABEL_SERVICE] not in service_names:
                        continue
                except KeyError:
                    continue
                yield build_container_event(event)

        return yield_loop(set(service_names) if service_names else self.service_names)
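
    # Shape of the dicts yielded by events() (a sketch; the field values below
    # are illustrative, not real output):
    #
    #   {
    #       'time': datetime.datetime(2019, 1, 1, 12, 0, 0, 123456),
    #       'type': 'container',
    #       'action': 'start',
    #       'id': '1234abcd...',
    #       'service': 'web',
    #       'attributes': {'name': 'myapp_web_1', 'image': 'nginx:alpine'},
    #       'container': <Container: myapp_web_1>,
    #   }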

    def up(self,
           service_names=None,
           start_deps=True,
           strategy=ConvergenceStrategy.changed,
           do_build=BuildAction.none,
           timeout=None,
           detached=False,
           remove_orphans=False,
           ignore_orphans=False,
           scale_override=None,
           rescale=True,
           start=True,
           always_recreate_deps=False,
           reset_container_image=False,
           renew_anonymous_volumes=False,
           silent=False,
           ):

        self.initialize()
        if not ignore_orphans:
            self.find_orphan_containers(remove_orphans)

        if scale_override is None:
            scale_override = {}

        services = self.get_services_without_duplicate(
            service_names,
            include_deps=start_deps)

        for svc in services:
            svc.ensure_image_exists(do_build=do_build, silent=silent)
        plans = self._get_convergence_plans(
            services, strategy, always_recreate_deps=always_recreate_deps)

        def do(service):
            return service.execute_convergence_plan(
                plans[service.name],
                timeout=timeout,
                detached=detached,
                scale_override=scale_override.get(service.name),
                rescale=rescale,
                start=start,
                reset_container_image=reset_container_image,
                renew_anonymous_volumes=renew_anonymous_volumes,
            )

        def get_deps(service):
            return {
                (self.get_service(dep), config)
                for dep, config in service.get_dependency_configs().items()
            }

        results, errors = parallel.parallel_execute(
            services,
            do,
            operator.attrgetter('name'),
            None,
            get_deps,
        )
        if errors:
            raise ProjectError(
                'Encountered errors while bringing up the project.'
            )

        return [
            container
            for svc_containers in results
            if svc_containers is not None
            for container in svc_containers
        ]
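
    # Typical call (a sketch): bring up "web" and its dependencies, recreating
    # only what changed, and detach from container output.
    #
    #   containers = project.up(
    #       service_names=['web'],
    #       strategy=ConvergenceStrategy.changed,
    #       detached=True,
    #   )
    #   # `containers` is the flat list of containers created/started above.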

    def initialize(self):
        self.networks.initialize()
        self.volumes.initialize()

    def _get_convergence_plans(self, services, strategy, always_recreate_deps=False):
        plans = {}

        for service in services:
            updated_dependencies = [
                name
                for name in service.get_dependency_names()
                if name in plans and
                plans[name].action in ('recreate', 'create')
            ]

            if updated_dependencies and strategy.allows_recreate:
                log.debug('%s has upstream changes (%s)',
                          service.name,
                          ", ".join(updated_dependencies))
                containers_stopped = any(
                    service.containers(stopped=True, filters={'status': ['created', 'exited']}))
                has_links = any(c.get('HostConfig.Links') for c in service.containers())
                if always_recreate_deps or containers_stopped or not has_links:
                    plan = service.convergence_plan(ConvergenceStrategy.always)
                else:
                    plan = service.convergence_plan(strategy)
            else:
                plan = service.convergence_plan(strategy)

            plans[service.name] = plan

        return plans

    def pull(self, service_names=None, ignore_pull_failures=False, parallel_pull=False, silent=False,
             include_deps=False):
        services = self.get_services(service_names, include_deps)
        msg = 'Pulling' if not silent else None

        if parallel_pull:
            def pull_service(service):
                strm = service.pull(ignore_pull_failures, True, stream=True)
                if strm is None:  # Attempting to pull service with no `image` key is a no-op
                    return

                writer = parallel.get_stream_writer()

                for event in strm:
                    if 'status' not in event:
                        continue
                    status = event['status'].lower()
                    if 'progressDetail' in event:
                        detail = event['progressDetail']
                        if 'current' in detail and 'total' in detail:
                            percentage = float(detail['current']) / float(detail['total'])
                            status = '{} ({:.1%})'.format(status, percentage)

                    writer.write(
                        msg, service.name, truncate_string(status), lambda s: s
                    )

            _, errors = parallel.parallel_execute(
                services,
                pull_service,
                operator.attrgetter('name'),
                msg,
                limit=5,
            )
            if errors:
                combined_errors = '\n'.join([
                    e.decode('utf-8') if isinstance(e, six.binary_type) else e for e in errors.values()
                ])
                raise ProjectError(combined_errors)

        else:
            for service in services:
                service.pull(ignore_pull_failures, silent=silent)

    def push(self, service_names=None, ignore_push_failures=False):
        unique_images = set()
        for service in self.get_services(service_names, include_deps=False):
            # Treat <image> and <image:latest> as the same image
            repo, tag, sep = parse_repository_tag(service.image_name)
            service_image_name = sep.join((repo, tag)) if tag else sep.join((repo, 'latest'))

            if service_image_name not in unique_images:
                service.push(ignore_push_failures)
                unique_images.add(service_image_name)

    def _labeled_containers(self, stopped=False, one_off=OneOffFilter.exclude):
        ctnrs = list(filter(None, [
            Container.from_ps(self.client, container)
            for container in self.client.containers(
                all=stopped,
                filters={'label': self.labels(one_off=one_off)})])
        )
        if ctnrs:
            return ctnrs

        # Fall back to containers labelled with the legacy project name
        # (the name with '-' and '_' stripped; see labels(legacy=True)).
        return list(filter(lambda c: c.has_legacy_proj_name(self.name), filter(None, [
            Container.from_ps(self.client, container)
            for container in self.client.containers(
                all=stopped,
                filters={'label': self.labels(one_off=one_off, legacy=True)})])
        ))

    def containers(self, service_names=None, stopped=False, one_off=OneOffFilter.exclude):
        if service_names:
            self.validate_service_names(service_names)
        else:
            service_names = self.service_names

        containers = self._labeled_containers(stopped, one_off)

        def matches_service_names(container):
            return container.labels.get(LABEL_SERVICE) in service_names

        return [c for c in containers if matches_service_names(c)]

    def find_orphan_containers(self, remove_orphans):
        def _find():
            containers = self._labeled_containers()
            for ctnr in containers:
                service_name = ctnr.labels.get(LABEL_SERVICE)
                if service_name not in self.service_names:
                    yield ctnr
        orphans = list(_find())
        if not orphans:
            return
        if remove_orphans:
            for ctnr in orphans:
                log.info('Removing orphan container "{0}"'.format(ctnr.name))
                ctnr.kill()
                ctnr.remove(force=True)
        else:
            log.warning(
                'Found orphan containers ({0}) for this project. If '
                'you removed or renamed this service in your compose '
                'file, you can run this command with the '
                '--remove-orphans flag to clean it up.'.format(
                    ', '.join(["{}".format(ctnr.name) for ctnr in orphans])
                )
            )

    def _inject_deps(self, acc, service):
        dep_names = service.get_dependency_names()

        if len(dep_names) > 0:
            dep_services = self.get_services(
                service_names=list(set(dep_names)),
                include_deps=True
            )
        else:
            dep_services = []

        dep_services.append(service)
        return acc + dep_services

    def build_container_operation_with_timeout_func(self, operation, options):
        """Return a callable that runs `operation` on a container, defaulting
        the timeout to the stop timeout of the container's service."""
        def container_operation_with_timeout(container):
            if options.get('timeout') is None:
                service = self.get_service(container.service)
                options['timeout'] = service.stop_timeout(None)
            return getattr(container, operation)(**options)
        return container_operation_with_timeout


def get_volumes_from(project, service_dict):
    volumes_from = service_dict.pop('volumes_from', None)
    if not volumes_from:
        return []

    def build_volume_from(spec):
        if spec.type == 'service':
            try:
                return spec._replace(source=project.get_service(spec.source))
            except NoSuchService:
                pass

        if spec.type == 'container':
            try:
                container = Container.from_id(project.client, spec.source)
                return spec._replace(source=container)
            except APIError:
                pass

        raise ConfigurationError(
            "Service \"{}\" mounts volumes from \"{}\", which is not the name "
            "of a service or container.".format(
                service_dict['name'],
                spec.source))

    return [build_volume_from(vf) for vf in volumes_from]


def get_secrets(service, service_secrets, secret_defs):
    secrets = []

    for secret in service_secrets:
        secret_def = secret_defs.get(secret.source)
        if not secret_def:
            raise ConfigurationError(
                "Service \"{service}\" uses an undefined secret \"{secret}\" "
                .format(service=service, secret=secret.source))

        if secret_def.get('external'):
            log.warning("Service \"{service}\" uses secret \"{secret}\" which is external. "
                        "External secrets are not available to containers created by "
                        "docker-compose.".format(service=service, secret=secret.source))
            continue

        if secret.uid or secret.gid or secret.mode:
            log.warning(
                "Service \"{service}\" uses secret \"{secret}\" with uid, "
                "gid, or mode. These fields are not supported by this "
                "implementation of the Compose file".format(
                    service=service, secret=secret.source
                )
            )

        secrets.append({'secret': secret, 'file': secret_def.get('file')})

    return secrets
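
# Example (a sketch): for a service "web" declaring `secrets: [site_key]` with a
# top-level definition `site_key: {file: ./site_key.txt}`, get_secrets() returns
# one entry per usable secret. The exact secret-spec type comes from the config
# layer and is only assumed here:
#
#   [{'secret': <ServiceSecret source='site_key' ...>, 'file': './site_key.txt'}]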


class NoSuchService(Exception):
    def __init__(self, name):
        if isinstance(name, six.binary_type):
            name = name.decode('utf-8')
        self.name = name
        self.msg = "No such service: %s" % self.name

    def __str__(self):
        return self.msg


class ProjectError(Exception):
    def __init__(self, msg):
        self.msg = msg