# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Defines abstractions around compiler artifacts produced in compiling micro TVM binaries."""

import io
import os
import json
import shutil
import tarfile


class ArtifactFileNotFoundError(Exception):
    """Raised when an artifact file cannot be found on disk."""


class ArtifactBadSymlinkError(Exception):
    """Raised when an artifact symlink points outside the base directory."""


class ArtifactBadArchiveError(Exception):
    """Raised when an artifact archive is malformed."""


def _is_within(path, root):
    """Return True when ``path`` is ``root`` itself or lies beneath it.

    Both arguments are expected to be absolute, normalized paths. A plain ``str.startswith``
    check is not sufficient: ``/a/bc`` starts with ``/a/b`` but is not inside it.
    """
    try:
        return os.path.commonpath([path, root]) == root
    except ValueError:
        # Raised e.g. for paths on different drives (Windows); such a path is never inside root.
        return False


def _validate_archive_members(tar_f, dest_dir):
    """Reject archive members that would be written, or would link, outside ``dest_dir``.

    Parameters
    ----------
    tar_f : tarfile.TarFile
        The open archive to inspect.
    dest_dir : str
        The directory the archive is about to be extracted into.

    Raises
    ------
    ArtifactBadArchiveError
        If a member is not a regular file, directory, symlink or hard link; if its path
        escapes ``dest_dir``; or if it is a link whose target escapes ``dest_dir``.
    """
    dest_dir = os.path.realpath(dest_dir)
    for member in tar_f.getmembers():
        if not (member.isreg() or member.isdir() or member.issym() or member.islnk()):
            raise ArtifactBadArchiveError(f"{member.name}: unsupported archive member type")

        member_path = os.path.realpath(os.path.join(dest_dir, member.name))
        if not _is_within(member_path, dest_dir):
            raise ArtifactBadArchiveError(f"{member.name}: path escapes the extraction directory")

        if member.issym():
            # Symlink targets are resolved relative to the directory containing the link
            # (os.path.join returns linkname unchanged when it is absolute).
            target = os.path.join(os.path.dirname(member_path), member.linkname)
        elif member.islnk():
            # Hard link targets name another member, i.e. are relative to the archive root.
            target = os.path.join(dest_dir, member.linkname)
        else:
            continue

        if not _is_within(os.path.realpath(target), dest_dir):
            raise ArtifactBadArchiveError(
                f"{member.name}: link target escapes the extraction directory"
            )


class Artifact:
    """Describes a compiler artifact and defines common logic to archive it for transport."""

    # A version number written to the archive.
    ENCODING_VERSION = 1

    # A unique string identifying the type of artifact in an archive. Subclasses must redefine this
    # variable.
    ARTIFACT_TYPE = None

    @classmethod
    def unarchive(cls, archive_path, base_dir):
        """Unarchive an artifact into base_dir.

        Parameters
        ----------
        archive_path : str
            Path to the archive file.
        base_dir : str
            Path to a non-existent, empty directory under which the artifact will live.

        Returns
        -------
        Artifact :
            The unarchived artifact. If the archive's ``artifact_type`` matches the
            ``ARTIFACT_TYPE`` of a direct subclass of ``cls``, an instance of that subclass is
            returned; otherwise an instance of ``cls``.

        Raises
        ------
        ValueError
            If base_dir already exists.
        ArtifactBadArchiveError
            If the archive does not contain exactly one top-level directory, lacks a valid
            metadata.json, has an unexpected encoding version, or contains members that would
            be extracted (or link) outside the extraction directory.
        """
        # Normalize so a trailing separator does not yield an empty base_dir_name below.
        base_dir = os.path.normpath(base_dir)
        if os.path.exists(base_dir):
            raise ValueError(f"base_dir exists: {base_dir}")

        base_dir_parent, base_dir_name = os.path.split(base_dir)
        # Extract into a sibling staging directory first, so a malformed archive never leaves a
        # partially populated base_dir behind; the staging directory is always removed below.
        temp_dir = os.path.join(base_dir_parent, f"__tvm__{base_dir_name}")
        os.mkdir(temp_dir)
        try:
            with tarfile.open(archive_path) as tar_f:
                _validate_archive_members(tar_f, temp_dir)
                if hasattr(tarfile, "data_filter"):
                    # Defense in depth on Python versions that ship tarfile extraction filters.
                    tar_f.extractall(temp_dir, filter="data")
                else:
                    tar_f.extractall(temp_dir)

            temp_dir_contents = os.listdir(temp_dir)
            if len(temp_dir_contents) != 1 or not os.path.isdir(
                os.path.join(temp_dir, temp_dir_contents[0])
            ):
                raise ArtifactBadArchiveError(
                    "Expected exactly 1 subdirectory at root of archive, got "
                    f"{temp_dir_contents!r}"
                )

            package_dir = os.path.join(temp_dir, temp_dir_contents[0])
            metadata_path = os.path.join(package_dir, "metadata.json")
            if not os.path.isfile(metadata_path):
                raise ArtifactBadArchiveError("No metadata.json found in archive")

            try:
                with open(metadata_path, encoding="utf-8") as metadata_f:
                    metadata = json.load(metadata_f)
            except ValueError as err:  # covers json.JSONDecodeError and UnicodeDecodeError
                raise ArtifactBadArchiveError(f"metadata.json is not valid JSON: {err}") from err

            if not isinstance(metadata, dict):
                raise ArtifactBadArchiveError("metadata.json does not contain a JSON object")

            version = metadata.get("version")
            if version != cls.ENCODING_VERSION:
                raise ArtifactBadArchiveError(
                    f"archive version: expect {cls.ENCODING_VERSION}, found {version}"
                )

            missing_keys = [key for key in ("labelled_files", "metadata") if key not in metadata]
            if missing_keys:
                raise ArtifactBadArchiveError(f"metadata.json is missing keys: {missing_keys!r}")

            os.rename(package_dir, base_dir)

            artifact_cls = cls
            for sub_cls in cls.__subclasses__():
                if sub_cls.ARTIFACT_TYPE is not None and sub_cls.ARTIFACT_TYPE == metadata.get(
                    "artifact_type"
                ):
                    artifact_cls = sub_cls
                    break

            return artifact_cls.from_unarchived(
                base_dir, metadata["labelled_files"], metadata["metadata"]
            )
        finally:
            shutil.rmtree(temp_dir)

    @classmethod
    def from_unarchived(cls, base_dir, labelled_files, metadata):
        """Construct an instance from data read back out of an archive.

        Subclasses may override this to post-process the unarchived data; the default simply
        calls the constructor.
        """
        return cls(base_dir, labelled_files, metadata)

    def __init__(self, base_dir, labelled_files, metadata):
        """Create a new artifact.

        Parameters
        ----------
        base_dir : str
            The path to a directory on disk which contains all the files in this artifact.
        labelled_files : Dict[str, str]
            A dict mapping a file label to the relative paths of the files that carry that label.
        metadata : Dict
            A dict containing arbitrary JSON-serializable key-value data describing the artifact.

        Raises
        ------
        ArtifactFileNotFoundError
            If a labelled file does not exist (a dangling symlink counts as existing here).
        ArtifactBadSymlinkError
            If a labelled file is a symlink that resolves outside base_dir.
        """
        self.base_dir = os.path.realpath(base_dir)
        self.labelled_files = labelled_files
        self.metadata = metadata

        for label, files in labelled_files.items():
            for f in files:
                f_path = os.path.join(self.base_dir, f)
                if not os.path.lexists(f_path):
                    raise ArtifactFileNotFoundError(f"{f} (label {label}): not found at {f_path}")

                if os.path.islink(f_path):
                    # os.path.join returns the link target unchanged when it is absolute, so this
                    # handles both absolute and relative link targets.
                    link_fullpath = os.path.realpath(
                        os.path.join(os.path.dirname(f_path), os.readlink(f_path))
                    )
                    if not _is_within(link_fullpath, self.base_dir):
                        raise ArtifactBadSymlinkError(
                            f"{f} (label {label}): symlink points outside artifact tree"
                        )

    def abspath(self, rel_path):
        """Return absolute path to the member with the given relative path."""
        return os.path.join(self.base_dir, rel_path)

    def label(self, label):
        """Return a list of relative paths to files with the given label."""
        return self.labelled_files[label]

    def label_abspath(self, label):
        """Return a list of absolute paths to files with the given label."""
        return [self.abspath(p) for p in self.labelled_files[label]]

    def archive(self, archive_path):
        """Create a relocatable tar archive of the artifacts.

        All files are stored under a single top-level directory named after the archive file
        (without extension), next to a generated metadata.json. Absolute symlinks are rewritten
        as relative symlinks so the archive can be extracted anywhere.

        Parameters
        ----------
        archive_path : str
            Path to the tar file to create. Or, path to a directory, under which a tar file will be
            created named after the basename of base_dir, with a .tar extension.

        Returns
        -------
        str :
            The value of archive_path, after potentially making the computation described above.
        """
        if os.path.isdir(archive_path):
            archive_path = os.path.join(archive_path, f"{os.path.basename(self.base_dir)}.tar")

        archive_name = os.path.splitext(os.path.basename(archive_path))[0]
        metadata_bytes = json.dumps(
            {
                "version": self.ENCODING_VERSION,
                # Lets unarchive() pick the matching Artifact subclass.
                "artifact_type": self.ARTIFACT_TYPE,
                "labelled_files": self.labelled_files,
                "metadata": self.metadata,
            },
            indent=2,
            sort_keys=True,
        ).encode("utf-8")

        with tarfile.open(archive_path, "w") as tar_f:
            metadata_info = tarfile.TarInfo(name=f"{archive_name}/metadata.json")
            metadata_info.type = tarfile.REGTYPE
            # Size must be the encoded byte length, not the character count.
            metadata_info.size = len(metadata_bytes)
            tar_f.addfile(metadata_info, io.BytesIO(metadata_bytes))

            for dir_path, dir_names, file_names in os.walk(self.base_dir):
                # os.walk lists symlinks to directories among dir_names without descending into
                # them; archive those links explicitly so they are not silently dropped.
                linked_dirs = [d for d in dir_names if os.path.islink(os.path.join(dir_path, d))]
                for name in file_names + linked_dirs:
                    file_path = os.path.join(dir_path, name)
                    rel_path = os.path.relpath(file_path, self.base_dir)
                    if rel_path == "metadata.json":
                        # Written fresh above; a stale copy left behind by unarchive() would
                        # otherwise be added a second time and win on extraction.
                        continue

                    archive_file_path = os.path.join(archive_name, rel_path)
                    link_path = os.readlink(file_path) if os.path.islink(file_path) else None
                    if link_path is None or not os.path.isabs(link_path):
                        # Regular files and relative symlinks are stored as-is (tarfile does not
                        # dereference symlinks by default).
                        tar_f.add(file_path, archive_file_path, recursive=False)
                        continue

                    # Rewrite absolute symlinks as relative ones so the archive is relocatable.
                    link_info = tarfile.TarInfo(name=archive_file_path)
                    link_info.type = tarfile.SYMTYPE
                    link_info.linkname = os.path.relpath(link_path, os.path.dirname(file_path))
                    tar_f.addfile(link_info)

        return archive_path