1import os
2
3import numpy as np
4
5from yt.data_objects.static_output import ParticleDataset, ParticleFile
6from yt.funcs import setdefaultattr
7from yt.geometry.particle_geometry_handler import ParticleIndex
8from yt.utilities.logger import ytLogger as mylog
9from yt.utilities.on_demand_imports import _requests as requests
10from yt.utilities.sdf import HTTPSDFRead, SDFIndex, SDFRead
11
12from .fields import SDFFieldInfo
13
# Hard-coded unit conversion factors for 2HOT headers that declare
# ``units_2HOT == 2``.  Their magnitudes correspond to 1 kpc in cm,
# 1e10 Msun in g and 1 Gyr in s.
# TODO: read these directly from the file header instead of hard-coding them.
units_2HOT_v2_length = 3.08567802e21
units_2HOT_v2_mass = 1.98892e43
units_2HOT_v2_time = 3.1558149984e16
19
20
class SDFFile(ParticleFile):
    """File handle class for SDF datasets; adds nothing to ParticleFile."""
23
24
class SDFDataset(ParticleDataset):
    """Particle dataset read from an SDF file, either local or over ``http``.

    The header is parsed with ``SDFRead`` for local paths or ``HTTPSDFRead``
    for names starting with ``http``.  An optional "midx" index file can be
    supplied at load time; it is opened lazily through :attr:`midx`.
    """

    _index_class = ParticleIndex
    _file_class = SDFFile
    _field_info_class = SDFFieldInfo
    _particle_mass_name = None
    _particle_coordinates_name = None
    _particle_velocity_name = None
    _midx = None  # cache for the ``midx`` property
    _skip_cache = True
    _subspace = False  # set to True when an explicit bounding_box is given

    def __init__(
        self,
        filename,
        dataset_type="sdf_particles",
        index_order=None,
        index_filename=None,
        bounding_box=None,
        sdf_header=None,
        midx_filename=None,
        midx_header=None,
        midx_level=None,
        field_map=None,
        units_override=None,
        unit_system="cgs",
    ):
        """Set up an SDF dataset.

        Parameters
        ----------
        filename : str
            Path to the SDF file, or a URL starting with ``http``.
        dataset_type : str
            Ignored.  The actual type is derived as
            ``[midx_][http_]sdf_particles`` from ``midx_filename`` and
            ``filename``.
        index_order, index_filename, units_override, unit_system
            Forwarded unchanged to ``ParticleDataset.__init__``.
        bounding_box : array-like, optional
            Left/right domain edges as a (3, 2) or (2, 3) array.  When given,
            it replaces the domain computed from the header.
        sdf_header : str, optional
            Passed as ``header=`` to the SDF reader.
        midx_filename, midx_header, midx_level : optional
            Location, header and level of the optional midx index file.
        field_map : dict, optional
            Stored as ``self._field_map``; defaults to an empty dict.

        Raises
        ------
        ValueError
            If ``bounding_box`` cannot be interpreted as (3, 2) edges.
        """
        if bounding_box is not None:
            # This ensures that we know a bounding box has been applied
            self._domain_override = True
            self._subspace = True
            bbox = np.array(bounding_box, dtype="float64")
            if bbox.shape == (2, 3):
                bbox = bbox.transpose()
            if bbox.shape != (3, 2):
                raise ValueError(
                    "bounding_box must have shape (3, 2) or (2, 3), "
                    f"got {np.shape(bounding_box)}"
                )
            self.domain_left_edge = bbox[:, 0]
            self.domain_right_edge = bbox[:, 1]
        else:
            self.domain_left_edge = self.domain_right_edge = None
        self.sdf_header = sdf_header
        self.midx_filename = midx_filename
        self.midx_header = midx_header
        self.midx_level = midx_level
        self._field_map = {} if field_map is None else field_map
        # The caller's dataset_type is deliberately overridden: it is derived
        # from whether a midx file is used and whether the data is remote.
        prefix = ""
        if self.midx_filename is not None:
            prefix += "midx_"
        if filename.startswith("http"):
            prefix += "http_"
        dataset_type = prefix + "sdf_particles"
        super().__init__(
            filename,
            dataset_type=dataset_type,
            units_override=units_override,
            unit_system=unit_system,
            index_order=index_order,
            index_filename=index_filename,
        )

    def _parse_parameter_file(self):
        """Read the SDF header and fill in domain, periodicity and cosmology."""
        if self.parameter_filename.startswith("http"):
            sdf_class = HTTPSDFRead
        else:
            sdf_class = SDFRead
        self.sdf_container = sdf_class(self.parameter_filename, header=self.sdf_header)

        # Reference
        self.parameters = self.sdf_container.parameters
        self.dimensionality = 3
        self.refine_by = 2

        if self.domain_left_edge is None or self.domain_right_edge is None:
            R0 = self.parameters["R0"]
            # Per-axis half-widths Rx/Ry/Rz, each falling back to R0.
            half_widths = np.array(
                [self.parameters.get(f"R{ax}", R0) for ax in "xyz"],
                dtype=np.float64,
            )
            if self.parameters.get("offset_center"):
                # Domain spans [0, 2R] on each axis.
                self.domain_left_edge = np.zeros(3, dtype=np.float64)
                self.domain_right_edge = 2.0 * half_widths
            else:
                # Domain is centered on the origin: [-R, +R].
                self.domain_left_edge = -half_widths
                self.domain_right_edge = half_widths
            # Both edges are scaled by the header's "a" (default 1.0).
            scale = self.parameters.get("a", 1.0)
            self.domain_left_edge *= scale
            self.domain_right_edge *= scale

        self.domain_dimensions = np.ones(3, "int32")
        periodic = bool(self.parameters.get("do_periodic"))
        self._periodicity = (periodic, periodic, periodic)

        self.cosmological_simulation = 1

        self.current_redshift = self.parameters.get("redshift", 0.0)
        self.omega_lambda = self.parameters["Omega0_lambda"]
        self.omega_matter = self.parameters["Omega0_m"]
        if "Omega0_fld" in self.parameters:
            self.omega_lambda += self.parameters["Omega0_fld"]
        if "Omega0_r" in self.parameters:
            # not correct, but most codes can't handle Omega0_r
            self.omega_matter += self.parameters["Omega0_r"]
        self.hubble_constant = self.parameters["h_100"]
        self.current_time = units_2HOT_v2_time * self.parameters.get("tpos", 0.0)
        mylog.info("Calculating time to be %0.3e seconds", self.current_time)
        self.filename_template = self.parameter_filename
        self.file_count = 1

    @property
    def midx(self):
        """Build the ``SDFIndex`` for ``midx_filename`` on first access and cache it.

        Raises
        ------
        RuntimeError
            If no ``midx_filename`` was supplied at load time.
        """
        if self._midx is None:
            if self.midx_filename is None:
                raise RuntimeError("SDF index0 file not supplied in load.")
            # Use the same URL test as for the main file.  A plain substring
            # check would route local paths containing "http" to HTTPSDFRead.
            if self.midx_filename.startswith("http"):
                sdf_class = HTTPSDFRead
            else:
                sdf_class = SDFRead
            indexdata = sdf_class(self.midx_filename, header=self.midx_header)
            self._midx = SDFIndex(self.sdf_container, indexdata, level=self.midx_level)
        return self._midx

    def _set_code_unit_attributes(self):
        """Set the code length/velocity/time/mass units.

        Header values are used when present, otherwise kpc, kpc/Gyr, Gyr and
        1e10 Msun.  Attributes that are already set (e.g. via
        ``units_override``) are left untouched by ``setdefaultattr``.
        """
        setdefaultattr(
            self,
            "length_unit",
            self.quan(1.0, self.parameters.get("length_unit", "kpc")),
        )
        setdefaultattr(
            self,
            "velocity_unit",
            self.quan(1.0, self.parameters.get("velocity_unit", "kpc/Gyr")),
        )
        setdefaultattr(
            self, "time_unit", self.quan(1.0, self.parameters.get("time_unit", "Gyr"))
        )
        mass_unit = self.parameters.get("mass_unit", "1e10 Msun")
        # Accept "<factor> <unit>" (e.g. "1e10 Msun") or a bare unit.  Split
        # on the first run of whitespace so that repeated or leading spaces
        # do not break the unpacking.
        parts = mass_unit.split(None, 1)
        if len(parts) == 2:
            factor, unit = parts
        else:
            factor, unit = 1.0, mass_unit.strip()
        setdefaultattr(self, "mass_unit", self.quan(float(factor), unit))

    @classmethod
    def _is_valid(cls, filename, *args, **kwargs):
        """Return True if the SDF header starts with the ``# SDF`` marker.

        The header comes from the ``sdf_header`` keyword when given and not
        None, otherwise from ``filename``.  Both local files and ``http`` URLs
        are supported.  Any failure to obtain the header text returns False:
        missing ``requests``, a non-200 response, an empty body, or an
        unreadable file.
        """
        # __init__ defaults sdf_header to None; treat that as "not given".
        sdf_header = kwargs.get("sdf_header") or filename
        if sdf_header.startswith("http"):
            try:
                hreq = requests.get(sdf_header, stream=True)
            except ImportError:
                # requests is not installed
                return False
            try:
                if hreq.status_code != 200:
                    return False
                # Grab a whole 4k page (empty bytes if the body is empty).
                chunk = next(hreq.iter_content(4096), b"")
            finally:
                # A streamed response holds its connection until closed.
                hreq.close()
            # iter_content yields bytes; decode it the same way local files
            # are read so the str comparison below is valid.
            line = chunk.decode("ISO-8859-1").strip()
        elif os.path.isfile(sdf_header):
            try:
                with open(sdf_header, encoding="ISO-8859-1") as f:
                    line = f.read(10).strip()
            except OSError:
                # e.g. permission denied: we could not load it anyway
                return False
        else:
            return False
        return line.startswith("# SDF")
196