1from __future__ import unicode_literals
2
3from dvc.utils.compat import str, open
4
5import os
6import yaml
7import subprocess
8
9from dvc.utils.fs import contains_symlink_up_to
10from schema import Schema, SchemaError, Optional, Or, And
11
12import dvc.prompt as prompt
13import dvc.logger as logger
14import dvc.dependency as dependency
15import dvc.output as output
16from dvc.exceptions import DvcException
17from dvc.utils import dict_md5, fix_env, load_stage_file
18
19
class StageCmdFailedError(DvcException):
    """Raised when the command of a stage finishes unsuccessfully."""

    def __init__(self, stage):
        template = "stage '{}' cmd {} failed"
        super(StageCmdFailedError, self).__init__(
            template.format(stage.relpath, stage.cmd)
        )
24
25
class StageFileFormatError(DvcException):
    """Raised when the content of a stage file fails validation."""

    def __init__(self, fname, e):
        details = str(e)
        super(StageFileFormatError, self).__init__(
            "stage file '{}' format error: {}".format(fname, details)
        )
30
31
class StageFileDoesNotExistError(DvcException):
    """Raised when the requested stage file is not present on disk.

    If a stage file named ``fname`` plus the stage suffix exists, the
    message suggests it to the user.
    """

    def __init__(self, fname):
        candidate = fname + Stage.STAGE_FILE_SUFFIX

        hint = ""
        if Stage.is_stage_file(candidate):
            hint = " Do you mean '{}'?".format(candidate)

        super(StageFileDoesNotExistError, self).__init__(
            "'{}' does not exist.".format(fname) + hint
        )
41
42
class StageFileAlreadyExistsError(DvcException):
    """Raised when an existing stage file must not be overwritten."""

    def __init__(self, relpath):
        template = "stage '{}' already exists"
        super(StageFileAlreadyExistsError, self).__init__(
            template.format(relpath)
        )
47
48
class StageFileIsNotDvcFileError(DvcException):
    """Raised when a path exists but is not a stage file.

    If a stage file named ``fname`` plus the stage suffix exists, the
    message suggests it to the user.
    """

    def __init__(self, fname):
        candidate = fname + Stage.STAGE_FILE_SUFFIX

        hint = ""
        if Stage.is_stage_file(candidate):
            hint = " Do you mean '{}'?".format(candidate)

        super(StageFileIsNotDvcFileError, self).__init__(
            "'{}' is not a dvc file".format(fname) + hint
        )
58
59
class StageFileBadNameError(DvcException):
    """Raised when a stage file name is not acceptable."""

    def __init__(self, msg):
        # The caller supplies the complete message; it is passed through as is.
        super(StageFileBadNameError, self).__init__(msg)
63
64
class StagePathOutsideError(DvcException):
    """Raised when a stage path lies outside of the repository root."""

    def __init__(self, path):
        template = "stage working or file path '{}' is outside of dvc repo"
        message = template.format(path)
        super(StagePathOutsideError, self).__init__(message)
69
70
class StagePathNotFoundError(DvcException):
    """Raised when a stage path does not exist on disk."""

    def __init__(self, path):
        template = "stage working or file path '{}' does not exist"
        message = template.format(path)
        super(StagePathNotFoundError, self).__init__(message)
75
76
class StagePathNotDirectoryError(DvcException):
    """Raised when a stage path exists but is not a directory."""

    def __init__(self, path):
        template = "stage working or file path '{}' is not directory"
        message = template.format(path)
        super(StagePathNotDirectoryError, self).__init__(message)
81
82
class StageCommitError(DvcException):
    """Raised when a changed stage is not allowed to be committed."""

    pass
85
86
class MissingDep(DvcException):
    """Raised when one or more dependencies of a stage do not exist.

    ``deps`` must be a non-empty collection; each item is rendered with
    ``str`` in the message.
    """

    def __init__(self, deps):
        assert len(deps) > 0

        # Pick the singular or plural noun depending on the number of items.
        noun = "dependencies" if len(deps) > 1 else "dependency"
        listing = ", ".join(str(item) for item in deps)

        super(MissingDep, self).__init__(
            "missing {}: {}".format(noun, listing)
        )
98
99
class MissingDataSource(DvcException):
    """Raised when one or more data sources are missing.

    ``missing_files`` must be a non-empty collection of strings.
    """

    def __init__(self, missing_files):
        assert len(missing_files) > 0

        # Pick the singular or plural noun depending on the number of items.
        noun = "sources" if len(missing_files) > 1 else "source"
        listing = ", ".join(missing_files)

        super(MissingDataSource, self).__init__(
            "missing data {}: {}".format(noun, listing)
        )
110
111
class Stage(object):
    """In-memory representation of a single stage file.

    A stage bundles an optional command (``cmd``), its working directory
    (``wdir``), dependencies (``deps``), outputs (``outs``), a checksum of
    the stage definition (``md5``) and a ``locked`` flag.
    """

    # Base name that is accepted as a stage file without any suffix.
    STAGE_FILE = "Dvcfile"
    # Suffix that marks any other file name as a stage file.
    STAGE_FILE_SUFFIX = ".dvc"

    # Keys of the serialized stage dictionary.
    PARAM_MD5 = "md5"
    PARAM_CMD = "cmd"
    PARAM_WDIR = "wdir"
    PARAM_DEPS = "deps"
    PARAM_OUTS = "outs"
    PARAM_LOCKED = "locked"

    # Validation schema for a serialized stage; every key is optional.
    SCHEMA = {
        Optional(PARAM_MD5): Or(str, None),
        Optional(PARAM_CMD): Or(str, None),
        Optional(PARAM_WDIR): Or(str, None),
        Optional(PARAM_DEPS): Or(And(list, Schema([dependency.SCHEMA])), None),
        Optional(PARAM_OUTS): Or(And(list, Schema([output.SCHEMA])), None),
        Optional(PARAM_LOCKED): bool,
    }
131
132    def __init__(
133        self,
134        repo,
135        path=None,
136        cmd=None,
137        wdir=os.curdir,
138        deps=None,
139        outs=None,
140        md5=None,
141        locked=False,
142    ):
143        if deps is None:
144            deps = []
145        if outs is None:
146            outs = []
147
148        self.repo = repo
149        self.path = path
150        self.cmd = cmd
151        self.wdir = wdir
152        self.outs = outs
153        self.deps = deps
154        self.md5 = md5
155        self.locked = locked
156
157    def __repr__(self):
158        return "Stage: '{path}'".format(
159            path=self.relpath if self.path else "No path"
160        )
161
162    @property
163    def relpath(self):
164        return os.path.relpath(self.path)
165
166    @property
167    def is_data_source(self):
168        """Whether the stage file was created with `dvc add` or `dvc import`"""
169        return self.cmd is None
170
171    @staticmethod
172    def is_valid_filename(path):
173        return (
174            path.endswith(Stage.STAGE_FILE_SUFFIX)
175            or os.path.basename(path) == Stage.STAGE_FILE
176        )
177
178    @staticmethod
179    def is_stage_file(path):
180        return os.path.isfile(path) and Stage.is_valid_filename(path)
181
182    def changed_md5(self):
183        return self.md5 != self._compute_md5()
184
185    @property
186    def is_callback(self):
187        """
188        A callback stage is always considered as changed,
189        so it runs on every `dvc repro` call.
190        """
191        return not self.is_data_source and len(self.deps) == 0
192
193    @property
194    def is_import(self):
195        """Whether the stage file was created with `dvc import`."""
196        return not self.cmd and len(self.deps) == 1 and len(self.outs) == 1
197
198    def _changed_deps(self):
199        if self.locked:
200            return False
201
202        if self.is_callback:
203            logger.warning(
204                "Dvc file '{fname}' is a 'callback' stage "
205                "(has a command and no dependencies) and thus always "
206                "considered as changed.".format(fname=self.relpath)
207            )
208            return True
209
210        for dep in self.deps:
211            if dep.changed():
212                logger.warning(
213                    "Dependency '{dep}' of '{stage}' changed.".format(
214                        dep=dep, stage=self.relpath
215                    )
216                )
217                return True
218
219        return False
220
221    def _changed_outs(self):
222        for out in self.outs:
223            if out.changed():
224                logger.warning(
225                    "Output '{out}' of '{stage}' changed.".format(
226                        out=out, stage=self.relpath
227                    )
228                )
229                return True
230
231        return False
232
233    def _changed_md5(self):
234        if self.changed_md5():
235            logger.warning("Dvc file '{}' changed.".format(self.relpath))
236            return True
237        return False
238
239    def changed(self):
240        ret = any(
241            [self._changed_deps(), self._changed_outs(), self._changed_md5()]
242        )
243
244        if ret:
245            msg = "Stage '{}' changed.".format(self.relpath)
246            color = "yellow"
247        else:
248            msg = "Stage '{}' didn't change.".format(self.relpath)
249            color = "green"
250
251        logger.info(logger.colorize(msg, color))
252
253        return ret
254
255    def remove_outs(self, ignore_remove=False):
256        """
257        Used mainly for `dvc remove --outs`
258        """
259        for out in self.outs:
260            out.remove(ignore_remove=ignore_remove)
261
262    def unprotect_outs(self):
263        for out in self.outs:
264            if out.scheme != "local" or not out.exists:
265                continue
266            self.repo.unprotect(out.path)
267
268    def remove(self):
269        self.remove_outs(ignore_remove=True)
270        os.unlink(self.path)
271
272    def reproduce(
273        self, force=False, dry=False, interactive=False, no_commit=False
274    ):
275        if not self.changed() and not force:
276            return None
277
278        if (self.cmd or self.is_import) and not self.locked and not dry:
279            # Removing outputs only if we actually have command to reproduce
280            self.remove_outs(ignore_remove=False)
281
282        msg = (
283            "Going to reproduce '{stage}'. "
284            "Are you sure you want to continue?".format(stage=self.relpath)
285        )
286
287        if interactive and not prompt.confirm(msg):
288            raise DvcException("reproduction aborted by the user")
289
290        logger.info("Reproducing '{stage}'".format(stage=self.relpath))
291
292        self.run(dry=dry, no_commit=no_commit)
293
294        logger.debug("'{stage}' was reproduced".format(stage=self.relpath))
295
296        return self
297
298    @staticmethod
299    def validate(d, fname=None):
300        from dvc.utils import convert_to_unicode
301
302        try:
303            Schema(Stage.SCHEMA).validate(convert_to_unicode(d))
304        except SchemaError as exc:
305            raise StageFileFormatError(fname, exc)
306
307    @classmethod
308    def _stage_fname(cls, fname, outs, add):
309        if fname:
310            return fname
311
312        if not outs:
313            return cls.STAGE_FILE
314
315        out = outs[0]
316        path_handler = out.remote.ospath
317
318        fname = path_handler.basename(out.path) + cls.STAGE_FILE_SUFFIX
319
320        fname = Stage._expand_to_path_on_add_local(
321            add, fname, out, path_handler
322        )
323
324        return fname
325
326    @staticmethod
327    def _expand_to_path_on_add_local(add, fname, out, path_handler):
328        if (
329            add
330            and out.is_local
331            and not contains_symlink_up_to(out.path, out.repo.root_dir)
332        ):
333            fname = path_handler.join(path_handler.dirname(out.path), fname)
334        return fname
335
336    @staticmethod
337    def _check_stage_path(repo, path):
338        assert repo is not None
339
340        real_path = os.path.realpath(path)
341        if not os.path.exists(real_path):
342            raise StagePathNotFoundError(path)
343
344        if not os.path.isdir(real_path):
345            raise StagePathNotDirectoryError(path)
346
347        proj_dir = os.path.realpath(repo.root_dir) + os.path.sep
348        if not (real_path + os.path.sep).startswith(proj_dir):
349            raise StagePathOutsideError(path)
350
351    @property
352    def is_cached(self):
353        """
354        Checks if this stage has been already ran and stored
355        """
356        from dvc.remote.local import RemoteLOCAL
357        from dvc.remote.s3 import RemoteS3
358
359        old = Stage.load(self.repo, self.path)
360        if old._changed_outs():
361            return False
362
363        # NOTE: need to save checksums for deps in order to compare them
364        # with what is written in the old stage.
365        for dep in self.deps:
366            dep.save()
367
368        old_d = old.dumpd()
369        new_d = self.dumpd()
370
371        # NOTE: need to remove checksums from old dict in order to compare
372        # it to the new one, since the new one doesn't have checksums yet.
373        old_d.pop(self.PARAM_MD5, None)
374        new_d.pop(self.PARAM_MD5, None)
375        outs = old_d.get(self.PARAM_OUTS, [])
376        for out in outs:
377            out.pop(RemoteLOCAL.PARAM_CHECKSUM, None)
378            out.pop(RemoteS3.PARAM_CHECKSUM, None)
379
380        return old_d == new_d
381
382    @staticmethod
383    def create(
384        repo=None,
385        cmd=None,
386        deps=None,
387        outs=None,
388        outs_no_cache=None,
389        metrics=None,
390        metrics_no_cache=None,
391        fname=None,
392        cwd=None,
393        wdir=None,
394        locked=False,
395        add=False,
396        overwrite=True,
397        ignore_build_cache=False,
398        remove_outs=False,
399    ):
400        if outs is None:
401            outs = []
402        if deps is None:
403            deps = []
404        if outs_no_cache is None:
405            outs_no_cache = []
406        if metrics is None:
407            metrics = []
408        if metrics_no_cache is None:
409            metrics_no_cache = []
410
411        # Backward compatibility for `cwd` option
412        if wdir is None and cwd is not None:
413            if fname is not None and os.path.basename(fname) != fname:
414                raise StageFileBadNameError(
415                    "stage file name '{fname}' may not contain subdirectories"
416                    " if '-c|--cwd' (deprecated) is specified. Use '-w|--wdir'"
417                    " along with '-f' to specify stage file path and working"
418                    " directory.".format(fname=fname)
419                )
420            wdir = cwd
421        else:
422            wdir = os.curdir if wdir is None else wdir
423
424        stage = Stage(repo=repo, wdir=wdir, cmd=cmd, locked=locked)
425
426        stage.outs = output.loads_from(stage, outs, use_cache=True)
427        stage.outs += output.loads_from(
428            stage, metrics, use_cache=True, metric=True
429        )
430        stage.outs += output.loads_from(stage, outs_no_cache, use_cache=False)
431        stage.outs += output.loads_from(
432            stage, metrics_no_cache, use_cache=False, metric=True
433        )
434        stage.deps = dependency.loads_from(stage, deps)
435
436        stage._check_circular_dependency()
437        stage._check_duplicated_arguments()
438
439        fname = Stage._stage_fname(fname, stage.outs, add=add)
440        wdir = os.path.abspath(wdir)
441
442        if cwd is not None:
443            path = os.path.join(wdir, fname)
444        else:
445            path = os.path.abspath(fname)
446
447        Stage._check_stage_path(repo, wdir)
448        Stage._check_stage_path(repo, os.path.dirname(path))
449
450        stage.wdir = wdir
451        stage.path = path
452
453        # NOTE: remove outs before we check build cache
454        if remove_outs:
455            stage.remove_outs(ignore_remove=False)
456            logger.warning("Build cache is ignored when using --remove-outs.")
457            ignore_build_cache = True
458        else:
459            stage.unprotect_outs()
460
461        if os.path.exists(path):
462            if not ignore_build_cache and stage.is_cached:
463                logger.info("Stage is cached, skipping.")
464                return None
465
466            msg = (
467                "'{}' already exists. Do you wish to run the command and "
468                "overwrite it?".format(stage.relpath)
469            )
470
471            if not overwrite and not prompt.confirm(msg):
472                raise StageFileAlreadyExistsError(stage.relpath)
473
474            os.unlink(path)
475
476        return stage
477
478    @staticmethod
479    def _check_dvc_filename(fname):
480        if not Stage.is_valid_filename(fname):
481            raise StageFileBadNameError(
482                "bad stage filename '{}'. Stage files should be named"
483                " 'Dvcfile' or have a '.dvc' suffix (e.g. '{}.dvc').".format(
484                    os.path.relpath(fname), os.path.basename(fname)
485                )
486            )
487
488    @staticmethod
489    def _check_file_exists(fname):
490        if not os.path.exists(fname):
491            raise StageFileDoesNotExistError(fname)
492
493    @staticmethod
494    def load(repo, fname):
495        Stage._check_file_exists(fname)
496        Stage._check_dvc_filename(fname)
497
498        if not Stage.is_stage_file(fname):
499            raise StageFileIsNotDvcFileError(fname)
500
501        d = load_stage_file(fname)
502
503        Stage.validate(d, fname=os.path.relpath(fname))
504        path = os.path.abspath(fname)
505
506        stage = Stage(
507            repo=repo,
508            path=path,
509            wdir=os.path.abspath(
510                os.path.join(
511                    os.path.dirname(path), d.get(Stage.PARAM_WDIR, ".")
512                )
513            ),
514            cmd=d.get(Stage.PARAM_CMD),
515            md5=d.get(Stage.PARAM_MD5),
516            locked=d.get(Stage.PARAM_LOCKED, False),
517        )
518
519        stage.deps = dependency.loadd_from(stage, d.get(Stage.PARAM_DEPS, []))
520        stage.outs = output.loadd_from(stage, d.get(Stage.PARAM_OUTS, []))
521
522        return stage
523
524    def dumpd(self):
525        return {
526            key: value
527            for key, value in {
528                Stage.PARAM_MD5: self.md5,
529                Stage.PARAM_CMD: self.cmd,
530                Stage.PARAM_WDIR: os.path.relpath(
531                    self.wdir, os.path.dirname(self.path)
532                ),
533                Stage.PARAM_LOCKED: self.locked,
534                Stage.PARAM_DEPS: [d.dumpd() for d in self.deps],
535                Stage.PARAM_OUTS: [o.dumpd() for o in self.outs],
536            }.items()
537            if value
538        }
539
540    def dump(self):
541        fname = self.path
542
543        self._check_dvc_filename(fname)
544
545        logger.info(
546            "Saving information to '{file}'.".format(
547                file=os.path.relpath(fname)
548            )
549        )
550        d = self.dumpd()
551
552        with open(fname, "w") as fd:
553            yaml.safe_dump(d, fd, default_flow_style=False)
554
555        self.repo.files_to_git_add.append(os.path.relpath(fname))
556
557    def _compute_md5(self):
558        from dvc.output.local import OutputLOCAL
559
560        d = self.dumpd()
561
562        # NOTE: removing md5 manually in order to not affect md5s in deps/outs
563        if self.PARAM_MD5 in d.keys():
564            del d[self.PARAM_MD5]
565
566        # Ignore the wdir default value. In this case stage file w/o
567        # wdir has the same md5 as a file with the default value specified.
568        # It's important for backward compatibility with pipelines that
569        # didn't have WDIR in their stage files.
570        if d.get(self.PARAM_WDIR) == ".":
571            del d[self.PARAM_WDIR]
572
573        # NOTE: excluding parameters that don't affect the state of the
574        # pipeline. Not excluding `OutputLOCAL.PARAM_CACHE`, because if
575        # it has changed, we might not have that output in our cache.
576        m = dict_md5(d, exclude=[self.PARAM_LOCKED, OutputLOCAL.PARAM_METRIC])
577        logger.debug("Computed stage '{}' md5: '{}'".format(self.relpath, m))
578        return m
579
580    def save(self):
581        for dep in self.deps:
582            dep.save()
583
584        for out in self.outs:
585            out.save()
586
587        self.md5 = self._compute_md5()
588
589    @staticmethod
590    def _changed_entries(entries):
591        ret = []
592        for entry in entries:
593            if entry.checksum and entry.changed_checksum():
594                ret.append(entry.rel_path)
595        return ret
596
597    def check_can_commit(self, force):
598        changed_deps = self._changed_entries(self.deps)
599        changed_outs = self._changed_entries(self.outs)
600
601        if changed_deps or changed_outs or self.changed_md5():
602            msg = (
603                "dependencies {}".format(changed_deps) if changed_deps else ""
604            )
605            msg += " and " if (changed_deps and changed_outs) else ""
606            msg += "outputs {}".format(changed_outs) if changed_outs else ""
607            msg += "md5" if not (changed_deps or changed_outs) else ""
608            msg += " of '{}' changed. Are you sure you commit it?".format(
609                self.relpath
610            )
611            if not force and not prompt.confirm(msg):
612                raise StageCommitError(
613                    "unable to commit changed '{}'. Use `-f|--force` to "
614                    "force.`".format(self.relpath)
615                )
616            self.save()
617
618    def commit(self):
619        for out in self.outs:
620            out.commit()
621
622    def _check_missing_deps(self):
623        missing = [dep for dep in self.deps if not dep.exists]
624
625        if any(missing):
626            raise MissingDep(missing)
627
628    @staticmethod
629    def _warn_if_fish(executable):  # pragma: no cover
630        if (
631            executable is None
632            or os.path.basename(os.path.realpath(executable)) != "fish"
633        ):
634            return
635
636        logger.warning(
637            "DVC detected that you are using fish as your default "
638            "shell. Be aware that it might cause problems by overwriting "
639            "your current environment variables with values defined "
640            "in '.fishrc', which might affect your command. See "
641            "https://github.com/iterative/dvc/issues/1307. "
642        )
643
644    def _check_circular_dependency(self):
645        from dvc.exceptions import CircularDependencyError
646
647        circular_dependencies = set(d.path for d in self.deps) & set(
648            o.path for o in self.outs
649        )
650
651        if circular_dependencies:
652            raise CircularDependencyError(circular_dependencies.pop())
653
654    def _check_duplicated_arguments(self):
655        from dvc.exceptions import ArgumentDuplicationError
656        from collections import Counter
657
658        path_counts = Counter(edge.path for edge in self.deps + self.outs)
659
660        for path, occurrence in path_counts.items():
661            if occurrence > 1:
662                raise ArgumentDuplicationError(path)
663
664    def _run(self):
665        self._check_missing_deps()
666        executable = os.getenv("SHELL") if os.name != "nt" else None
667        self._warn_if_fish(executable)
668
669        p = subprocess.Popen(
670            self.cmd,
671            cwd=self.wdir,
672            shell=True,
673            env=fix_env(os.environ),
674            executable=executable,
675        )
676        p.communicate()
677
678        if p.returncode != 0:
679            raise StageCmdFailedError(self)
680
681    def run(self, dry=False, resume=False, no_commit=False):
682        if self.locked:
683            logger.info(
684                "Verifying outputs in locked stage '{stage}'".format(
685                    stage=self.relpath
686                )
687            )
688            if not dry:
689                self.check_missing_outputs()
690
691        elif self.is_import:
692            logger.info(
693                "Importing '{dep}' -> '{out}'".format(
694                    dep=self.deps[0].path, out=self.outs[0].path
695                )
696            )
697            if not dry:
698                if self._already_cached():
699                    self.outs[0].checkout()
700                else:
701                    self.deps[0].download(
702                        self.outs[0].path_info, resume=resume
703                    )
704
705        elif self.is_data_source:
706            msg = "Verifying data sources in '{}'".format(self.relpath)
707            logger.info(msg)
708            if not dry:
709                self.check_missing_outputs()
710
711        else:
712            logger.info("Running command:\n\t{}".format(self.cmd))
713            if not dry:
714                if not self.is_callback and self._already_cached():
715                    self.checkout()
716                else:
717                    self._run()
718
719        if not dry:
720            self.save()
721            if not no_commit:
722                self.commit()
723
724    def check_missing_outputs(self):
725        paths = [
726            out.path if out.scheme != "local" else out.rel_path
727            for out in self.outs
728            if not out.exists
729        ]
730
731        if paths:
732            raise MissingDataSource(paths)
733
734    def checkout(self, force=False):
735        for out in self.outs:
736            out.checkout(force=force)
737
738    @staticmethod
739    def _status(entries):
740        ret = {}
741
742        for entry in entries:
743            ret.update(entry.status())
744
745        return ret
746
747    def status(self):
748        ret = []
749
750        if not self.locked:
751            deps_status = self._status(self.deps)
752            if deps_status:
753                ret.append({"changed deps": deps_status})
754
755        outs_status = self._status(self.outs)
756        if outs_status:
757            ret.append({"changed outs": outs_status})
758
759        if self.changed_md5():
760            ret.append("changed checksum")
761
762        if self.is_callback:
763            ret.append("always changed")
764
765        if ret:
766            return {self.relpath: ret}
767
768        return {}
769
770    def _already_cached(self):
771        return (
772            not self.changed_md5()
773            and all(not dep.changed() for dep in self.deps)
774            and all(
775                not out.changed_cache() if out.use_cache else not out.changed()
776                for out in self.outs
777            )
778        )
779