import scipy.sparse as sparse
import numpy as np
from typing import *
from loompy import timestamp


def _renumber(a: np.ndarray, keys: np.ndarray, values: np.ndarray) -> np.ndarray:
    """
    Renumber 'a' by replacing any occurrence of 'keys' by the corresponding 'values'

    Args:
        a:      Array to renumber (any shape)
        keys:   The values that occur in 'a'
        values: Replacement for each element of 'keys' (values[i] replaces keys[i])

    Returns:
        A new array with the same shape as 'a'

    Remarks:
        Every element of 'a' is expected to occur in 'keys'. An element that does not is mapped to the
        value of the next larger key, or raises IndexError if it is larger than every key.
    """
    a = np.asarray(a)
    keys = np.asarray(keys)
    values = np.asarray(values)
    ordering = np.argsort(keys)
    keys = keys[ordering]
    # Reorder values together with keys so that keys[i] still maps to values[i]
    values = values[ordering]
    # With right=True, digitize returns i such that keys[i-1] < x <= keys[i],
    # i.e. the position of x among the sorted keys when x is one of the keys
    index = np.digitize(a.ravel(), keys, right=True)
    return values[index].reshape(a.shape)


def _induced_subgraph(g: sparse.coo_matrix, thing: Any) -> sparse.coo_matrix:
    """
    Return the subgraph of 'g' induced by the nodes selected by 'thing', without making the matrix dense

    Args:
        g:     Adjacency matrix whose rows index the nodes
        thing: A slice, numpy array (integer indices or boolean mask) or int, applied to the node indices

    Returns:
        A square coo_matrix over the selected nodes, keeping only the edges whose two endpoints are both
        selected. Selected nodes are renumbered 0..k-1 in ascending order of their original index.
    """
    # atleast_1d so that a single integer selects a 1x1 graph instead of producing a scalar
    indices = np.atleast_1d(np.arange(g.shape[0])[thing])
    n_selected = indices.shape[0]
    # Keep only edges with both endpoints among the selected nodes
    mask = np.isin(g.row, indices) & np.isin(g.col, indices)
    # Vectorized lookup table: old node index -> new node index (ascending order of the old index).
    # NOTE(review): unsorted fancy indices are not reordered here; confirm this matches how callers
    # slice the other per-node data of the same axis.
    new_index = np.zeros(g.shape[0], dtype=int)
    new_index[np.sort(indices)] = np.arange(n_selected)
    a = new_index[g.row[mask]]
    b = new_index[g.col[mask]]
    return sparse.coo_matrix((g.data[mask], (a, b)), shape=(n_selected, n_selected))


class GraphManager:
    """
    Manage a set of graphs (either for rows or columns) with a backing HDF5 file store

    Graphs are sparse.coo_matrix adjacency matrices. When backed by a file (ds is not None), each graph is
    stored in the HDF5 group '/row_graphs/<name>' (axis 0) or '/col_graphs/<name>' (axis 1) as three datasets:
    'a' (row indices), 'b' (column indices) and 'w' (weights). Graphs are read from the file on first access
    and cached. When ds is None (e.g. a manager produced by slicing), graphs are held in memory only.
    """

    def __init__(self, ds: Any, *, axis: int) -> None:
        """
        Args:
            ds:   The owning connection (accessed via its '_file' h5py file and its 'shape'),
                  or None for an in-memory manager
            axis: 0 for row graphs, 1 for column graphs
        """
        # Names starting with "!" bypass graph storage in __setattr__ and become plain instance attributes
        setattr(self, "!axis", axis)
        setattr(self, "!ds", ds)
        storage: Dict[str, Optional[sparse.coo_matrix]] = {}
        setattr(self, "!storage", storage)

        if ds is None:
            return

        f = ds._file
        if f.mode == "r+":
            # Make sure both graph groups exist
            for group in ("row_graphs", "col_graphs"):
                if group not in f:
                    f.create_group("/" + group)
            # Patch old files that use the old naming convention (move the links, then drop the old group)
            for old, new in (("row_edges", "row_graphs"), ("col_edges", "col_graphs")):
                if old in f:
                    for key in f[old]:
                        f[new][key] = f[old][key]
                    del f[old]

        a = ["row_graphs", "col_graphs"][self.axis]
        if a in f:
            # Register every stored graph; None means "not yet loaded from the file"
            for key in f[a]:
                storage[key] = None
        elif f.mode == "r+":
            f.create_group(a)

    def keys(self) -> List[str]:
        return list(self.__dict__["storage"].keys())

    def items(self) -> Iterable[Tuple[str, sparse.coo_matrix]]:
        for key in self.keys():
            yield (key, self[key])

    def __len__(self) -> int:
        return len(self.__dict__["storage"])

    def __contains__(self, name: str) -> bool:
        return name in self.__dict__["storage"]

    def __iter__(self) -> Iterator[str]:
        # Iterate over a snapshot of the names so graphs may be added/removed during iteration
        for key in self.keys():
            yield key

    def last_modified(self, name: Optional[str] = None) -> str:
        """
        Return a compact ISO8601 timestamp (UTC timezone) indicating when a graph was last modified

        Args:
            name: Name of the graph, or None (the default) for all graphs on this axis

        Remarks:
            If no graph name is given, the timestamp stored on the axis group is returned; it is updated
            whenever any graph on this axis is written, so it reflects the most recently modified graph.
            If the graphs do not contain a timestamp and the file mode is 'r+', a new timestamp is created,
            stored and returned. Otherwise (including when there is no backing file), the current time in
            UTC is returned.
        """
        if self.ds is None:
            return timestamp()
        f = self.ds._file
        a = ["row_graphs", "col_graphs"][self.axis]
        # Either the whole axis group, or the named graph's own group
        node = f[a] if name is None else f[a][name]
        if "last_modified" in node.attrs:
            return node.attrs["last_modified"]
        if f.mode == "r+":
            node.attrs["last_modified"] = timestamp()
            f.flush()
            return node.attrs["last_modified"]
        return timestamp()

    def __getitem__(self, thing: Any) -> sparse.coo_matrix:
        """
        Get a graph by name, or slice all graphs

        Args:
            thing: A graph name; a tuple of alternative graph names (the first one present is returned);
                   or a slice, numpy array (integer indices or boolean mask) or int selecting nodes

        Returns:
            The named graph, or (when slicing) a new in-memory GraphManager holding the induced
            subgraph of every graph

        Raises:
            AttributeError: if no matching graph exists
        """
        if type(thing) is slice or type(thing) is np.ndarray or type(thing) is int:
            gm = GraphManager(None, axis=self.axis)
            for key, g in self.items():
                gm[key] = _induced_subgraph(g, thing)
            return gm
        elif type(thing) is tuple:
            # A tuple of strings giving alternative names for graphs
            for t in thing:
                if t in self.__dict__["storage"]:
                    return self.__getattr__(t)
            raise AttributeError(f"'{type(self)}' object has no attribute {thing}")
        else:
            return self.__getattr__(thing)

    def __getattr__(self, name: str) -> sparse.coo_matrix:
        """
        Return the named graph, loading it from the HDF5 file on first access

        Raises:
            AttributeError: if there is no graph with the given name
        """
        # __getattr__ is only called when normal lookup fails. Read instance state via __dict__.get so that a
        # partially constructed instance (e.g. during copy/unpickling, before 'storage' and 'axis' exist)
        # raises AttributeError instead of recursing into __getattr__ forever.
        storage = self.__dict__.get("storage")
        axis = self.__dict__.get("axis")
        if storage is None or name not in storage:
            raise AttributeError(f"'{type(self)}' object has no graph '{name}' on axis {axis}")
        g = storage[name]
        if g is None:
            # Read values from the HDF5 file
            try:
                ds = self.__dict__["ds"]
                group = ds._file[["row_graphs", "col_graphs"][axis]][name]
                n = ds.shape[axis]
                g = sparse.coo_matrix((group["w"], (group["a"], group["b"])), shape=(n, n))
            except KeyError as e:
                raise AttributeError(f"'{type(self)}' object has no graph '{name}' on axis {axis}") from e
            storage[name] = g
        return g

    def __setitem__(self, name: str, g: sparse.coo_matrix) -> None:
        return self.__setattr__(name, g)

    def __setattr__(self, name: str, g: sparse.coo_matrix) -> None:
        """
        Store a graph (converted to coo_matrix), writing it to the backing file if there is one

        Names starting with "!" instead set a plain instance attribute named without the "!".

        Raises:
            KeyError: if the graph name contains a slash
            ValueError: if the matrix is not (n, n) where n is the size of the managed axis
        """
        if name.startswith("!"):
            super().__setattr__(name[1:], g)
            return
        if "/" in name:
            raise KeyError("Graph name cannot contain slash (/)")
        g = sparse.coo_matrix(g)
        if self.ds is not None:
            n = self.ds.shape[self.axis]
            if g.shape[0] != n or g.shape[1] != n:
                raise ValueError(f"Adjacency matrix shape for axis {self.axis} must be ({n},{n}) but shape was {g.shape}")
            f = self.ds._file
            group = f[["row_graphs", "col_graphs"][self.axis]]
            if name in group:
                # Deleting the graph's group also unlinks its 'a', 'b' and 'w' datasets
                del group[name]
            graph = group.create_group(name)
            graph["a"] = g.row
            graph["b"] = g.col
            graph["w"] = g.data
            # One timestamp so the graph, its axis group and the file all agree
            now = timestamp()
            graph.attrs["last_modified"] = now
            group.attrs["last_modified"] = now
            f.attrs["last_modified"] = now
            f.flush()
        self.__dict__["storage"][name] = g

    def __delitem__(self, name: str) -> None:
        return self.__delattr__(name)

    def __delattr__(self, name: str) -> None:
        """Remove the named graph from the backing file (if any) and from memory; missing graphs are ignored"""
        if self.ds is not None:
            group = self.ds._file[["row_graphs", "col_graphs"][self.axis]]
            if name in group:
                del group[name]
                self.ds._file.flush()
        self.__dict__["storage"].pop(name, None)

    def _permute(self, ordering: np.ndarray) -> None:
        """
        Renumber the nodes of every graph in place: the node with old index ordering[i] becomes node i
        """
        ordering = np.array(ordering)
        for name in self.keys():
            g = self[name]
            new_ids = np.arange(g.shape[1])
            a = _renumber(g.row, ordering, new_ids)
            b = _renumber(g.col, ordering, new_ids)
            self[name] = sparse.coo_matrix((g.data, (a, b)), g.shape)