1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
"""Defines abstractions around compiler artifacts produced in compiling micro TVM binaries."""
19
20import io
21import os
22import json
23import shutil
24import tarfile
25
26
class ArtifactFileNotFoundError(Exception):
    """Signals that a file listed under one of an artifact's labels does not exist on disk."""
29
30
class ArtifactBadSymlinkError(Exception):
    """Signals that a symlink inside an artifact resolves to a location outside its base_dir."""
33
34
class ArtifactBadArchiveError(Exception):
    """Signals that an artifact archive does not have the expected layout or metadata."""
37
38
class Artifact:
    """Describes a compiler artifact and defines common logic to archive it for transport."""

    # A version number written to the archive.
    ENCODING_VERSION = 1

    # A unique string identifying the type of artifact in an archive. Subclasses must redefine this
    # variable.
    ARTIFACT_TYPE = None

    @classmethod
    def unarchive(cls, archive_path, base_dir):
        """Unarchive an artifact into base_dir.

        The archive is first extracted into a sibling staging directory named
        ``__tvm__<basename of base_dir>``. Its single top-level directory is then renamed to
        ``base_dir``. The staging directory is removed whether or not unarchiving succeeds.

        Parameters
        ----------
        archive_path : str
            Path to the archive file.
        base_dir : str
            Path to a non-existent, empty directory under which the artifact will live.

        Returns
        -------
        Artifact :
            The unarchived artifact. If a direct subclass of ``cls`` declares an ``ARTIFACT_TYPE``
            equal to the archive's recorded ``artifact_type``, an instance of that subclass is
            returned; otherwise an instance of ``cls``.

        Raises
        ------
        ValueError
            If base_dir already exists.
        ArtifactBadArchiveError
            If the archive does not hold exactly one top-level entry, has no metadata.json, or
            was written with a different ENCODING_VERSION.
        """
        if os.path.exists(base_dir):
            raise ValueError(f"base_dir exists: {base_dir}")

        # normpath strips a trailing separator so split() yields the real directory name.
        base_dir_parent, base_dir_name = os.path.split(os.path.normpath(base_dir))
        temp_dir = os.path.join(base_dir_parent, f"__tvm__{base_dir_name}")
        os.mkdir(temp_dir)
        try:
            with tarfile.open(archive_path) as tar_f:
                # When the running Python supports extraction filters (PEP 706), refuse members
                # that would escape temp_dir (absolute paths, "..", links pointing outside).
                if hasattr(tarfile, "data_filter"):
                    tar_f.extractall(temp_dir, filter="data")
                else:
                    tar_f.extractall(temp_dir)

            temp_dir_contents = os.listdir(temp_dir)
            if len(temp_dir_contents) != 1:
                raise ArtifactBadArchiveError(
                    "Expected exactly 1 subdirectory at root of archive, got "
                    f"{temp_dir_contents!r}"
                )

            archive_root = os.path.join(temp_dir, temp_dir_contents[0])
            metadata_path = os.path.join(archive_root, "metadata.json")
            if not os.path.isfile(metadata_path):
                raise ArtifactBadArchiveError("No metadata.json found in archive")

            with open(metadata_path, encoding="utf-8") as metadata_f:
                metadata = json.load(metadata_f)

            version = metadata.get("version")
            if version != cls.ENCODING_VERSION:
                raise ArtifactBadArchiveError(
                    f"archive version: expect {cls.ENCODING_VERSION}, found {version}"
                )

            os.rename(archive_root, base_dir)

            # Dispatch to the direct subclass whose ARTIFACT_TYPE matches the archive, if any.
            artifact_type = metadata.get("artifact_type")
            artifact_cls = cls
            for sub_cls in cls.__subclasses__():
                if sub_cls.ARTIFACT_TYPE is not None and sub_cls.ARTIFACT_TYPE == artifact_type:
                    artifact_cls = sub_cls
                    break

            return artifact_cls.from_unarchived(
                base_dir, metadata["labelled_files"], metadata["metadata"]
            )
        finally:
            shutil.rmtree(temp_dir)

    @classmethod
    def from_unarchived(cls, base_dir, labelled_files, metadata):
        """Construct an instance from the pieces read back by `unarchive`.

        Subclasses may override this to post-process what was stored in the archive.
        """
        return cls(base_dir, labelled_files, metadata)

    @staticmethod
    def _is_within_dir(path, directory):
        """Return True if absolute `path` equals `directory` or lies beneath it."""
        try:
            return os.path.commonpath([path, directory]) == directory
        except ValueError:
            # Raised for paths with no common root (e.g. different Windows drives).
            return False

    def __init__(self, base_dir, labelled_files, metadata):
        """Create a new artifact.

        Parameters
        ----------
        base_dir : str
            The path to a directory on disk which contains all the files in this artifact. It is
            stored in resolved form (``os.path.realpath``).
        labelled_files : Dict[str, List[str]]
            A dict mapping each file label to the list of paths, relative to base_dir, of the
            files that carry that label.
        metadata : Dict
            A dict containing arbitrary JSON-serializable key-value data describing the artifact.

        Raises
        ------
        ArtifactFileNotFoundError
            If a labelled file does not exist under base_dir.
        ArtifactBadSymlinkError
            If a labelled file is a symlink that resolves outside base_dir.
        """
        self.base_dir = os.path.realpath(base_dir)
        self.labelled_files = labelled_files
        self.metadata = metadata

        for label, files in labelled_files.items():
            for f in files:
                f_path = os.path.join(self.base_dir, f)
                # lexists so that a dangling symlink still counts as present here.
                if not os.path.lexists(f_path):
                    raise ArtifactFileNotFoundError(f"{f} (label {label}): not found at {f_path}")

                if os.path.islink(f_path):
                    # os.path.join discards dirname(f_path) when the target is absolute, so this
                    # resolves both absolute and relative link targets.
                    link_fullpath = os.path.realpath(
                        os.path.join(os.path.dirname(f_path), os.readlink(f_path))
                    )
                    # Compare whole path components; a plain string prefix test would accept
                    # siblings such as "<base_dir>2/...".
                    if not self._is_within_dir(link_fullpath, self.base_dir):
                        raise ArtifactBadSymlinkError(
                            f"{f} (label {label}): symlink points outside artifact tree"
                        )

    def abspath(self, rel_path):
        """Return absolute path to the member with the given relative path."""
        return os.path.join(self.base_dir, rel_path)

    def label(self, label):
        """Return a list of relative paths to files with the given label."""
        return self.labelled_files[label]

    def label_abspath(self, label):
        """Return a list of absolute paths to files with the given label."""
        return [self.abspath(p) for p in self.labelled_files[label]]

    def archive(self, archive_path):
        """Create a relocatable tar archive of the artifacts.

        All entries live under a single top-level directory named after the archive file (minus
        its extension). A freshly generated metadata.json is written at that root. Absolute
        symlinks are rewritten as relative symlinks so the archive can be moved.

        Parameters
        ----------
        archive_path : str
            Path to the tar file to create. Or, path to a directory, under which a tar file will be
            created named after the basename of base_dir, with a ".tar" suffix.

        Returns
        -------
        str :
            The value of archive_path, after potentially making the computation described above.
        """
        if os.path.isdir(archive_path):
            archive_path = os.path.join(archive_path, f"{os.path.basename(self.base_dir)}.tar")

        archive_name = os.path.splitext(os.path.basename(archive_path))[0]
        with tarfile.open(archive_path, "w") as tar_f:

            def _add_file(name, data, f_type):
                tar_info = tarfile.TarInfo(name=name)
                tar_info.type = f_type
                data_bytes = bytes(data, "utf-8")
                # The header must record the encoded byte count, not the character count.
                tar_info.size = len(data_bytes)
                tar_f.addfile(tar_info, io.BytesIO(data_bytes))

            _add_file(
                f"{archive_name}/metadata.json",
                json.dumps(
                    {
                        "version": self.ENCODING_VERSION,
                        # Read by unarchive() to pick the matching subclass.
                        "artifact_type": self.ARTIFACT_TYPE,
                        "labelled_files": self.labelled_files,
                        "metadata": self.metadata,
                    },
                    indent=2,
                    sort_keys=True,
                ),
                tarfile.REGTYPE,
            )
            for dir_path, _, files in os.walk(self.base_dir):
                for f in files:
                    if dir_path == self.base_dir and f == "metadata.json":
                        # An unarchived artifact carries the previous metadata.json. Skip it so it
                        # cannot shadow the freshly generated one on extraction.
                        continue

                    file_path = os.path.join(dir_path, f)
                    archive_file_path = os.path.join(
                        archive_name, os.path.relpath(file_path, self.base_dir)
                    )

                    if os.path.islink(file_path):
                        link_path = os.readlink(file_path)
                        if os.path.isabs(link_path):
                            # Resolve only the target's directory components (so a link to another
                            # link keeps pointing at it), then make the target relative to this
                            # link's directory, which is already under the resolved base_dir.
                            resolved_target = os.path.join(
                                os.path.realpath(os.path.dirname(link_path)),
                                os.path.basename(link_path),
                            )
                            tar_info = tar_f.gettarinfo(file_path, archive_file_path)
                            tar_info.linkname = os.path.relpath(resolved_target, dir_path)
                            tar_f.addfile(tar_info)
                            continue

                    # Regular files and relative symlinks (added as symlinks, not dereferenced).
                    tar_f.add(file_path, archive_file_path, recursive=False)

        return archive_path
220