from collections import defaultdict

import numpy as np

from yt.utilities.cython_fortran_utils import FortranFile
from yt.utilities.exceptions import (
    YTFieldTypeNotFound,
    YTFileNotParseable,
    YTParticleOutputFormatNotImplemented,
)
from yt.utilities.io_handler import BaseIOHandler
from yt.utilities.logger import ytLogger as mylog
from yt.utilities.physical_ratios import cm_per_km, cm_per_mpc

from .definitions import VAR_DESC_RE, VERSION_RE


def convert_ramses_ages(ds, conformal_ages):
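    """Convert conformal birth times to physical ages.

    Cosmological RAMSES runs record particle birth times as conformal
    times; this interpolates the dataset's Friedmann lookup tables
    (``t_frw``, ``tau_frw``) linearly to recover the corresponding
    physical times, rescaled by ``t_scale``.
    """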
    tf = ds.t_frw
    dtau = ds.dtau
    tauf = ds.tau_frw
    tsim = ds.time_simu
    h100 = ds.hubble_constant
    nOver2 = ds.n_frw / 2
    unit_t = ds.parameters["unit_t"]
    t_scale = 1.0 / (h100 * 100 * cm_per_km / cm_per_mpc) / unit_t

    # calculate the index into the lookup tables (which have
    # n_frw elements)
    dage = 1 + (10 * conformal_ages / dtau)
    dage = np.minimum(dage, nOver2 + (dage - nOver2) / 10.0)
    iage = np.array(dage, dtype=np.int32)

    # linearly interpolate physical times from the tf and tauf
    # lookup tables
    t = tf[iage] * (conformal_ages - tauf[iage - 1]) / (tauf[iage] - tauf[iage - 1])
    t = t + (
        tf[iage - 1] * (conformal_ages - tauf[iage]) / (tauf[iage - 1] - tauf[iage])
    )
    return (tsim - t) * t_scale


def _ramses_particle_file_handler(fname, foffsets, data_types, subset, fields, count):
    """General file handler, called by _read_particle_subset

    Parameters
    ----------
    fname : string
        filename to read from
    foffsets : dict
        Offsets in file of the fields
    data_types : dict
        Data type of the fields
    subset : ``RAMSESDomainSubset``
        A RAMSES domain subset object
    fields : list of tuple
        The fields to read
    count : integer
        The number of particles to read
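
    Returns
    -------
    tr : dict
        Mapping from each requested field tuple to the array read from file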
    """
    tr = {}
    ds = subset.domain.ds
    with FortranFile(fname) as fd:
        # We do *all* conversion into boxlen here.
        # This means that no other conversions need to be applied to convert
        # positions into the same domain as the octs themselves.
        for field in sorted(fields, key=lambda a: foffsets[a]):
            if count == 0:
                tr[field] = np.empty(0, dtype=data_types[field])
                continue
            fd.seek(foffsets[field])
            dt = data_types[field]
            tr[field] = fd.read_vector(dt)
            if field[1].startswith("particle_position"):
                np.divide(tr[field], ds["boxlen"], tr[field])
            if ds.cosmological_simulation and field[1] == "particle_birth_time":
                conformal_age = tr[field]
                tr[field] = convert_ramses_ages(ds, conformal_age)
                # particles with zero conformal age are DM particles;
                # keep their physical age at zero as well
                tr[field][conformal_age == 0] = 0
    return tr


class IOHandlerRAMSES(BaseIOHandler):
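    """I/O handler for RAMSES fluid and particle data."""
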
    _dataset_type = "ramses"

    def _read_fluid_selection(self, chunks, selector, fields, size):
        tr = defaultdict(list)

        # Set of field types
        ftypes = {f[0] for f in fields}
        for chunk in chunks:
            # Gather fields by type to minimize I/O operations
            for ft in ftypes:
                # Get all the fields of the same type
                field_subs = list(filter(lambda f: f[0] == ft, fields))

                # Loop over subsets
                for subset in chunk.objs:
                    fname = None
                    for fh in subset.domain.field_handlers:
                        if fh.ftype == ft:
                            file_handler = fh
                            fname = fh.fname
                            break

                    if fname is None:
                        raise YTFieldTypeNotFound(ft)

                    # Now we read the entire thing
                    with FortranFile(fname) as fd:
                        # This contains the boundary information, so we skim through
                        # and pick off the right vectors
                        rv = subset.fill(fd, field_subs, selector, file_handler)
                    for ft, f in field_subs:
                        d = rv.pop(f)
                        mylog.debug(
                            "Filling %s with %s (%0.3e %0.3e) (%s zones)",
                            f,
                            d.size,
                            d.min(),
                            d.max(),
                            d.size,
                        )
                        tr[(ft, f)].append(d)
        d = {}
        for field in fields:
            d[field] = np.concatenate(tr.pop(field))

        return d

    def _read_particle_coords(self, chunks, ptf):
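        """Yield, for each chunk and particle type, the particle positions."""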
        pn = "particle_position_%s"
        fields = [
            (ptype, f"particle_position_{ax}")
            for ptype, field_list in ptf.items()
            for ax in "xyz"
        ]
        for chunk in chunks:
            for subset in chunk.objs:
                rv = self._read_particle_subset(subset, fields)
                for ptype in sorted(ptf):
                    yield ptype, (
                        rv[ptype, pn % "x"],
                        rv[ptype, pn % "y"],
                        rv[ptype, pn % "z"],
                    )

    def _read_particle_fields(self, chunks, ptf, selector):
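        """Yield ``(ptype, field)`` data for the particles kept by ``selector``."""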
        pn = "particle_position_%s"
        chunks = list(chunks)
        fields = [
            (ptype, fname) for ptype, field_list in ptf.items() for fname in field_list
        ]
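        # Positions are always read so that the spatial selector can be
        # applied, even when the caller did not request them explicitly.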
        for ptype, field_list in sorted(ptf.items()):
            for ax in "xyz":
                if pn % ax not in field_list:
                    fields.append((ptype, pn % ax))
        for chunk in chunks:
            for subset in chunk.objs:
                rv = self._read_particle_subset(subset, fields)
                for ptype, field_list in sorted(ptf.items()):
                    x, y, z = (np.asarray(rv[ptype, pn % ax], "=f8") for ax in "xyz")
                    mask = selector.select_points(x, y, z, 0.0)
                    if mask is None:
                        mask = []
                    for field in field_list:
                        data = np.asarray(rv.pop((ptype, field))[mask], "=f8")
                        yield (ptype, field), data
    def _read_particle_subset(self, subset, fields):
        """Read the particle files."""
        tr = {}

        # Sequential read depending on particle type
        for ptype in {f[0] for f in fields}:
            # Select relevant fields
            subs_fields = filter(lambda f: f[0] == ptype, fields)

            ok = False
            for ph in subset.domain.particle_handlers:
                if ph.ptype == ptype:
                    fname = ph.fname
                    foffsets = ph.field_offsets
                    data_types = ph.field_types
                    ok = True
                    count = ph.local_particle_count
                    break
            if not ok:
                raise YTFieldTypeNotFound(ptype)

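            # Cosmological runs store the birth time on disk as a conformal
            # time; expose it as "conformal_birth_time" too by aliasing the
            # same file offset and data type.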
            cosmo = self.ds.cosmological_simulation
            if (ptype, "particle_birth_time") in foffsets and cosmo:
                foffsets[ptype, "conformal_birth_time"] = foffsets[
                    ptype, "particle_birth_time"
                ]
                data_types[ptype, "conformal_birth_time"] = data_types[
                    ptype, "particle_birth_time"
                ]

            tr.update(
                _ramses_particle_file_handler(
                    fname, foffsets, data_types, subset, subs_fields, count=count
                )
            )

        return tr


def _read_part_file_descriptor(fname):
    """
    Read a particle file descriptor and return the list of fields found.
    """
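    # A version-1 descriptor is a small CSV-like text file, roughly of the
    # form:
    #
    #   # version:  1
    #   # ivar, variable_name, variable_type
    #     1, position_x, d
    #     2, velocity_x, d
    #
    # where the last column encodes the data type ("d" for double
    # precision, for instance).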

    # Mapping
    mapping = [
        ("position_x", "particle_position_x"),
        ("position_y", "particle_position_y"),
        ("position_z", "particle_position_z"),
        ("velocity_x", "particle_velocity_x"),
        ("velocity_y", "particle_velocity_y"),
        ("velocity_z", "particle_velocity_z"),
        ("mass", "particle_mass"),
        ("identity", "particle_identity"),
        ("levelp", "particle_level"),
        ("family", "particle_family"),
        ("tag", "particle_tag"),
    ]
    # Convert to dictionary
    mapping = dict(mapping)

    with open(fname) as f:
        line = f.readline()
        tmp = VERSION_RE.match(line)
        mylog.debug("Reading part file descriptor %s.", fname)
        if not tmp:
            raise YTParticleOutputFormatNotImplemented()

        version = int(tmp.group(1))

        if version == 1:
            # Skip one line (containing the headers)
            line = f.readline()
            fields = []
            for i, line in enumerate(f.readlines()):
                tmp = VAR_DESC_RE.match(line)
                if not tmp:
                    raise YTFileNotParseable(fname, i + 1)

                # ivar = tmp.group(1)
                varname = tmp.group(2)
                dtype = tmp.group(3)

                if varname in mapping:
                    varname = mapping[varname]
                else:
                    varname = f"particle_{varname}"

                fields.append((varname, dtype))
        else:
            raise YTParticleOutputFormatNotImplemented()

    return fields


def _read_fluid_file_descriptor(fname):
    """
    Read a fluid file descriptor and return the list of fields found.
    """

    # Mapping
    mapping = [
        ("density", "Density"),
        ("velocity_x", "x-velocity"),
        ("velocity_y", "y-velocity"),
        ("velocity_z", "z-velocity"),
        ("pressure", "Pressure"),
        ("metallicity", "Metallicity"),
    ]

    # Add mapping for magnetic fields
    mapping += [
        (key, key)
        for key in (
            f"B_{dim}_{side}" for side in ["left", "right"] for dim in ["x", "y", "z"]
        )
    ]

    # Convert to dictionary
    mapping = dict(mapping)

    with open(fname) as f:
        line = f.readline()
        tmp = VERSION_RE.match(line)
        mylog.debug("Reading fluid file descriptor %s.", fname)
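        # Unlike the particle descriptor reader, fall back to an empty field
        # list instead of raising when the header does not match.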
        if not tmp:
            return []

        version = int(tmp.group(1))

        if version == 1:
            # Skip one line (containing the headers)
            line = f.readline()
            fields = []
            for i, line in enumerate(f.readlines()):
                tmp = VAR_DESC_RE.match(line)
                if not tmp:
                    raise YTFileNotParseable(fname, i + 1)

                # ivar = tmp.group(1)
                varname = tmp.group(2)
                dtype = tmp.group(3)

                if varname in mapping:
                    varname = mapping[varname]
                else:
                    varname = f"hydro_{varname}"

                fields.append((varname, dtype))
        else:
            mylog.error("Unsupported fluid file descriptor version %s", version)
            raise YTParticleOutputFormatNotImplemented()

    return fields