1import scipy.sparse as sparse
2import numpy as np
3from typing import *
4from loompy import timestamp
5
6
def _renumber(a: np.ndarray, keys: np.ndarray, values: np.ndarray) -> np.ndarray:
	"""
	Renumber 'a' by replacing any occurrence of 'keys' by the corresponding 'values'

	Args:
		a: Array (any shape) to renumber. Every element is expected to occur in 'keys'; an element
			that does not is mapped to the value of the next larger key, or raises IndexError if it
			is larger than every key.
		keys: 1-D array of distinct keys
		values: 1-D array of the same length as 'keys'; keys[i] is replaced by values[i]

	Returns:
		A new array with the same shape as 'a'
	"""
	a = np.asarray(a)
	keys = np.asarray(keys)
	values = np.asarray(values)
	# Sort the keys, carrying their values along, so each element of 'a' can be found by binary search.
	# (Previously the values were taken from the keys themselves, which ignored 'values' entirely.)
	ordering = np.argsort(keys)
	sorted_keys = keys[ordering]
	sorted_values = values[ordering]
	index = np.searchsorted(sorted_keys, a.ravel())
	return sorted_values[index].reshape(a.shape)
16
17
class GraphManager:
	"""
	Manage a set of graphs (either for rows or columns) with a backing HDF5 file store

	Graphs are sparse adjacency matrices keyed by name. When a dataset 'ds' is given, the graphs are
	read lazily from the HDF5 group 'row_graphs' (axis 0) or 'col_graphs' (axis 1) of 'ds._file' on
	first access, and written back to it on assignment. When 'ds' is None (e.g. the result of
	slicing), graphs are kept in memory only.

	Graphs can be accessed as attributes (gm.knn), by name (gm["knn"]), or by a tuple of alternative
	names (gm[("knn", "KNN")]). Indexing with a slice, integer or numpy array returns a new in-memory
	GraphManager in which every graph is restricted to the selected nodes.
	"""
	def __init__(self, ds: Any, *, axis: int) -> None:
		"""
		Args:
			ds: The owning dataset, or None for a purely in-memory manager. When given, it must expose
				'_file' (an open HDF5 file) and 'shape'.
			axis: 0 for row graphs, 1 for column graphs
		"""
		# Names prefixed with "!" bypass __setattr__'s graph handling and become plain attributes
		setattr(self, "!axis", axis)
		setattr(self, "!ds", ds)
		storage: Dict[str, Optional[sparse.coo_matrix]] = {}
		setattr(self, "!storage", storage)

		if ds is None:
			return

		f = ds._file
		if f.mode == "r+":
			# Ensure both graph groups exist, and patch old files that use the old naming convention
			# ('row_edges' / 'col_edges') by moving their entries under the new group names
			for new_group, old_group in (("row_graphs", "row_edges"), ("col_graphs", "col_edges")):
				if new_group not in f:
					f.create_group("/" + new_group)
				if old_group in f:
					for key in f[old_group]:
						f[new_group][key] = f[old_group][key]
					del f[old_group]

		# Register the graphs present in the file; None means "not loaded yet" (see __getattr__).
		# Writable files always have the group at this point; a read-only file may lack it, in which
		# case the manager simply starts out empty.
		a = ["row_graphs", "col_graphs"][axis]
		if a in f:
			for key in f[a]:
				storage[key] = None

	def keys(self) -> List[str]:
		"""Return the names of all graphs (loaded or not) as a new list."""
		return list(self.__dict__["storage"].keys())

	def items(self) -> Iterable[Tuple[str, sparse.coo_matrix]]:
		"""Yield (name, graph) pairs, loading each graph from the file if necessary."""
		for key in self.keys():
			yield (key, self[key])

	def __len__(self) -> int:
		return len(self.keys())

	def __contains__(self, name: str) -> bool:
		return name in self.keys()

	def __iter__(self) -> Iterator[str]:
		# Iterate over a snapshot so graphs may be added/removed while iterating
		return iter(self.keys())

	def last_modified(self, name: Optional[str] = None) -> str:
		"""
		Return a compact ISO8601 timestamp (UTC timezone) indicating when a graph was last modified

		Note: if no graph name is given (the default), the modification time of the most recently modified graph will be returned
		Note: if the graphs do not contain a timestamp, and the mode is 'r+', a new timestamp is created and returned.
		Otherwise (including when there is no backing file), the current time in UTC will be returned.

		Raises:
			KeyError: if 'name' is given but no such graph exists in the file
		"""
		if self.ds is None:
			return timestamp()
		f = self.ds._file
		graphs = f[["row_graphs", "col_graphs"][self.axis]]
		# The axis group's timestamp is updated on every graph write (see __setattr__),
		# so it reflects the most recently modified graph
		node = graphs if name is None else graphs[name]
		if "last_modified" in node.attrs:
			return node.attrs["last_modified"]
		if f.mode == "r+":
			node.attrs["last_modified"] = timestamp()
			f.flush()
			return node.attrs["last_modified"]
		return timestamp()

	def __getitem__(self, thing: Any) -> Union[sparse.coo_matrix, "GraphManager"]:
		"""
		Look up graphs by name, or restrict all graphs to a subset of nodes.

		Args:
			thing: A graph name; a tuple of alternative names (the first one present is returned);
				or a slice, integer or numpy array (index or boolean mask) selecting nodes.

		Returns:
			The named graph, or - for a node selection - a new in-memory GraphManager whose graphs
			keep only edges with both endpoints selected, renumbered to the selected nodes' ranks.

		Raises:
			AttributeError: if no graph with the given name(s) exists
		"""
		# type(...) is int (not isinstance) so that bools are not treated as node indices
		if isinstance(thing, (slice, np.ndarray, np.integer)) or type(thing) is int:
			gm = GraphManager(None, axis=self.axis)
			for key, g in self.items():
				# Slice the graph matrix properly without making it dense.
				# atleast_1d: a single integer index yields a 0-d result otherwise
				indices = np.atleast_1d(np.arange(g.shape[0])[thing])
				selected = np.sort(indices)
				keep = np.isin(g.row, selected) & np.isin(g.col, selected)
				# New node number = position among the sorted selection; for repeated indices the
				# last position is used (side="right" - 1)
				rows = np.searchsorted(selected, g.row[keep], side="right") - 1
				cols = np.searchsorted(selected, g.col[keep], side="right") - 1
				size = len(indices)
				gm[key] = sparse.coo_matrix((g.data[keep], (rows, cols)), shape=(size, size))
			return gm
		if type(thing) is tuple:
			# A tuple of strings giving alternative names for graphs
			for t in thing:
				if t in self.__dict__["storage"]:
					return self.__getattr__(t)
			raise AttributeError(f"'{type(self)}' object has no attribute {thing}")
		return self.__getattr__(thing)

	def __getattr__(self, name: str) -> sparse.coo_matrix:
		"""
		Return graph 'name', loading it from the HDF5 file on first access.

		Only invoked for names that are not regular attributes.

		Raises:
			AttributeError: if there is no graph called 'name'
		"""
		# Use __dict__.get so that an instance whose __init__ did not run (e.g. during copy or
		# unpickling) raises AttributeError instead of recursing forever
		storage = self.__dict__.get("storage")
		if storage is None or name not in storage:
			raise AttributeError(f"'{type(self)}' object has no graph '{name}' on axis {self.__dict__.get('axis')}")
		g = storage[name]
		if g is None:
			# Read values from the HDF5 file
			group = self.ds._file[["row_graphs", "col_graphs"][self.axis]][name]
			n = self.ds.shape[self.axis]
			g = sparse.coo_matrix((group["w"][:], (group["a"][:], group["b"][:])), shape=(n, n))
			storage[name] = g
		return g

	def __setitem__(self, name: str, g: Any) -> None:
		return self.__setattr__(name, g)

	def __setattr__(self, name: str, g: Any) -> None:
		"""
		Store graph 'g' (anything accepted by sparse.coo_matrix) under 'name'.

		When backed by a dataset, the graph is written to the file as 'a' (rows), 'b' (cols) and
		'w' (weights) datasets, and the graph, axis group and file 'last_modified' attributes are set.
		Names starting with "!" are internal: the prefix is stripped and the value is stored as a
		plain attribute instead.

		Raises:
			KeyError: if 'name' contains a slash (it would create a nested HDF5 path)
			ValueError: if the matrix is not square with size ds.shape[axis]
		"""
		if name.startswith("!"):
			super().__setattr__(name[1:], g)
			return
		if "/" in name:
			raise KeyError("Graph name cannot contain slash (/)")
		g = sparse.coo_matrix(g)
		if self.ds is not None:
			n = self.ds.shape[self.axis]
			if g.shape[0] != n or g.shape[1] != n:
				raise ValueError(f"Adjacency matrix shape for axis {self.axis} must be ({n},{n}) but shape was {g.shape}")
			f = self.ds._file
			graphs = f[["row_graphs", "col_graphs"][self.axis]]
			if name in graphs:
				# Deleting the group also unlinks its 'a', 'b' and 'w' datasets
				del graphs[name]
			group = graphs.create_group(name)
			group["a"] = g.row
			group["b"] = g.col
			group["w"] = g.data
			# One timestamp so that the graph, its axis group and the file agree exactly
			now = timestamp()
			group.attrs["last_modified"] = now
			graphs.attrs["last_modified"] = now
			f.attrs["last_modified"] = now
			f.flush()
		self.__dict__["storage"][name] = g

	def __delitem__(self, name: str) -> None:
		return self.__delattr__(name)

	def __delattr__(self, name: str) -> None:
		"""
		Remove graph 'name' from memory and, when backed by a dataset, from the file.

		Deleting a graph that does not exist is silently ignored.
		"""
		if self.ds is not None:
			graphs = self.ds._file[["row_graphs", "col_graphs"][self.axis]]
			if name in graphs:
				# Deleting the group also unlinks its 'a', 'b' and 'w' datasets
				del graphs[name]
				self.ds._file.flush()
		self.__dict__["storage"].pop(name, None)

	def _permute(self, ordering: np.ndarray) -> None:
		"""
		Renumber the nodes of every graph so that old node ordering[i] becomes node i.

		Graphs are re-assigned, so file-backed graphs are rewritten.
		"""
		ordering = np.asarray(ordering)
		for name in self.keys():
			g = self[name]
			new_numbers = np.arange(g.shape[1])
			rows = _renumber(g.row, ordering, new_numbers)
			cols = _renumber(g.col, ordering, new_numbers)
			self[name] = sparse.coo_matrix((g.data, (rows, cols)), g.shape)
189