1import dask
2from distributed.client import Client, _get_global_client
3from distributed.worker import get_worker
4
5from fsspec import filesystem
6from fsspec.spec import AbstractBufferedFile, AbstractFileSystem
7from fsspec.utils import infer_storage_options
8
9
def _get_client(client):
    """Resolve the ``client`` argument into a dask distributed client.

    Parameters
    ----------
    client : Client, None or other
        An existing ``Client`` is returned unchanged. ``None`` falls back to
        ``_get_global_client()`` (which may itself return ``None``). Any other
        value, e.g. a ``"host:port"`` connection string, is passed to
        ``Client(...)`` to create a new client.
    """
    if isinstance(client, Client):
        return client
    if client is None:
        return _get_global_client()
    # Anything else is treated as something Client() can connect with.
    return Client(client)
18
19
class DaskWorkerFileSystem(AbstractFileSystem):
    """View files accessible to a worker as any other remote file-system

    When instances are run on the worker, uses the real filesystem. When
    run on the client, they call the worker to provide information or data.

    **Warning** this implementation is experimental. File objects opened
    from the client side are read-only (only mode ``"rb"`` is accepted);
    ``mkdir``, ``rm``, ``copy`` and ``mv`` are forwarded to the worker's
    filesystem.
    """

    def __init__(
        self, target_protocol=None, target_options=None, fs=None, client=None, **kwargs
    ):
        """
        Parameters
        ----------
        target_protocol : str or None
            Protocol of the filesystem to instantiate on the worker with
            ``fsspec.filesystem``. Mutually exclusive with ``fs``.
        target_options : dict or None
            Keyword arguments used when instantiating ``target_protocol``.
        fs : AbstractFileSystem or None
            Ready-made filesystem instance to use on the worker. Mutually
            exclusive with ``target_protocol``.
        client : Client, str or None
            Dask client used on the client side. ``None`` selects the global
            client. Any value that is not a ``Client`` (e.g. a connection
            string) is passed to ``Client(...)``.
        **kwargs
            Forwarded to ``AbstractFileSystem.__init__``.

        Raises
        ------
        ValueError
            If both, or neither, of ``fs`` and ``target_protocol`` are given.
        """
        super().__init__(**kwargs)
        # XOR: exactly one of ``fs`` / ``target_protocol`` must be supplied.
        if not (fs is None) ^ (target_protocol is None):
            raise ValueError(
                "Please provide one of filesystem instance (fs) or"
                " target_protocol, not both"
            )
        self.target_protocol = target_protocol
        self.target_options = target_options
        self.worker = None
        self.client = client
        self.fs = fs
        self._determine_worker()

    @staticmethod
    def _get_kwargs_from_urls(path):
        """Derive constructor kwargs from a URL.

        Returns ``{"client": "host:port"}`` when the URL has both a host and
        a port. Otherwise returns an empty dict.
        """
        so = infer_storage_options(path)
        if "host" in so and "port" in so:
            return {"client": f"{so['host']}:{so['port']}"}
        else:
            return {}

    def _determine_worker(self):
        """Configure this instance for worker-side or client-side operation.

        On a worker, ``self.worker`` is True and ``self.fs`` holds the real
        filesystem. Elsewhere, ``self.worker`` is False, ``self.client`` is
        resolved via ``_get_client`` and ``self.rfs`` is a ``dask.delayed``
        proxy of this instance. Method calls on ``self.rfs`` are executed by
        ``.compute()``.
        """
        # Only the worker lookup belongs in the try: a ValueError raised
        # while building the target filesystem must not be mistaken for
        # "not running on a worker".
        try:
            get_worker()
        except ValueError:
            # get_worker() signals "not inside a worker" with ValueError.
            in_worker = False
        else:
            in_worker = True

        if in_worker:
            self.worker = True
            if self.fs is None:
                self.fs = filesystem(
                    self.target_protocol, **(self.target_options or {})
                )
        else:
            self.worker = False
            self.client = _get_client(self.client)
            # NOTE(review): .compute() on this proxy uses dask's current
            # default scheduler, presumably the client resolved above;
            # confirm when the client is created with set_as_default=False.
            self.rfs = dask.delayed(self)

    def mkdir(self, *args, **kwargs):
        """Create a directory using the worker's filesystem."""
        if self.worker:
            self.fs.mkdir(*args, **kwargs)
        else:
            self.rfs.mkdir(*args, **kwargs).compute()

    def rm(self, *args, **kwargs):
        """Remove path(s) using the worker's filesystem."""
        if self.worker:
            self.fs.rm(*args, **kwargs)
        else:
            self.rfs.rm(*args, **kwargs).compute()

    def copy(self, *args, **kwargs):
        """Copy within the worker's filesystem."""
        if self.worker:
            self.fs.copy(*args, **kwargs)
        else:
            self.rfs.copy(*args, **kwargs).compute()

    def mv(self, *args, **kwargs):
        """Move within the worker's filesystem."""
        if self.worker:
            self.fs.mv(*args, **kwargs)
        else:
            self.rfs.mv(*args, **kwargs).compute()

    def ls(self, *args, **kwargs):
        """List a path on the worker's filesystem and return the listing."""
        if self.worker:
            return self.fs.ls(*args, **kwargs)
        else:
            return self.rfs.ls(*args, **kwargs).compute()

    def _open(
        self,
        path,
        mode="rb",
        block_size=None,
        autocommit=True,
        cache_options=None,
        **kwargs,
    ):
        """Open *path*.

        On a worker, returns the underlying filesystem's raw file object.
        On a client, returns a ``DaskFile`` whose byte ranges are fetched
        through ``self.fetch_range``. ``DaskFile`` only accepts mode "rb".
        """
        if self.worker:
            return self.fs._open(
                path,
                mode=mode,
                block_size=block_size,
                autocommit=autocommit,
                cache_options=cache_options,
                **kwargs,
            )
        else:
            return DaskFile(
                fs=self,
                path=path,
                mode=mode,
                block_size=block_size,
                autocommit=autocommit,
                cache_options=cache_options,
                **kwargs,
            )

    def fetch_range(self, path, mode, start, end):
        """Return up to ``end - start`` bytes of *path*, starting at ``start``.

        On a worker the file is opened locally. On a client the call is
        executed remotely via the delayed proxy.
        """
        if self.worker:
            with self._open(path, mode) as f:
                f.seek(start)
                return f.read(end - start)
        else:
            return self.rfs.fetch_range(path, mode, start, end).compute()
132
133
class DaskFile(AbstractBufferedFile):
    """Read-only buffered file backed by ``fs.fetch_range``.

    Every byte range requested by the buffering machinery is obtained by
    calling ``self.fs.fetch_range(path, mode, start, end)``. For a
    client-side ``DaskWorkerFileSystem`` that call is executed on a worker.
    """

    def __init__(self, mode="rb", **kwargs):
        if mode == "rb":
            # ``mode`` is not forwarded; the base class defaults to "rb",
            # the only mode accepted here.
            super().__init__(**kwargs)
        else:
            raise ValueError('Remote dask files can only be opened in "rb" mode')

    def _upload_chunk(self, final=False):
        """No-op: writing is impossible because only "rb" is accepted."""

    def _initiate_upload(self):
        """Create remote file/upload (no-op for this read-only file)."""

    def _fetch_range(self, start, end):
        """Get the specified set of bytes from remote"""
        target, file_mode = self.path, self.mode
        return self.fs.fetch_range(target, file_mode, start, end)
150