1import re
2import sys
3from numbers import Number as numeric_type
4
5import numpy as np
6from more_itertools import first, mark_ends
7
8from yt.data_objects.construction_data_containers import YTCoveringGrid
9from yt.data_objects.image_array import ImageArray
10from yt.fields.derived_field import DerivedField
11from yt.funcs import fix_axis, is_sequence, iter_fields, mylog
12from yt.units import dimensions
13from yt.units.unit_object import Unit
14from yt.units.yt_array import YTArray, YTQuantity
15from yt.utilities.on_demand_imports import _astropy
16from yt.utilities.parallel_tools.parallel_analysis_interface import parallel_root_only
17from yt.visualization.fixed_resolution import FixedResolutionBuffer, ParticleImageBuffer
18from yt.visualization.particle_plots import ParticleAxisAlignedDummyDataSource
19from yt.visualization.volume_rendering.off_axis_projection import off_axis_projection
20
21
class UnitfulHDU:
    """Thin wrapper around a FITS HDU that exposes its data with yt units.

    The field name and unit string come from the HDU header's ``BTYPE`` and
    ``BUNIT`` cards respectively.
    """

    def __init__(self, hdu):
        header = hdu.header
        self.hdu = hdu
        self.header = header
        self.name = header["BTYPE"]
        self.units = header["BUNIT"]
        self.shape = hdu.shape

    @property
    def data(self):
        """The HDU's data wrapped in a ``YTArray`` carrying ``self.units``."""
        return YTArray(self.hdu.data, self.units)

    def __repr__(self):
        dims = " x ".join(map(str, self.shape))
        return f"FITSImage: {self.name} ({dims}, {self.units})"
37
38
39class FITSImageData:
40    def __init__(
41        self,
42        data,
43        fields=None,
44        length_unit=None,
45        width=None,
46        img_ctr=None,
47        wcs=None,
48        current_time=None,
49        time_unit=None,
50        mass_unit=None,
51        velocity_unit=None,
52        magnetic_unit=None,
53        ds=None,
54        unit_header=None,
55        **kwargs,
56    ):
57        r"""Initialize a FITSImageData object.
58
59        FITSImageData contains a collection of FITS ImageHDU instances and
60        WCS information, along with units for each of the images. FITSImageData
61        instances can be constructed from ImageArrays, NumPy arrays, dicts
62        of such arrays, FixedResolutionBuffers, and YTCoveringGrids. The latter
63        two are the most powerful because WCS information can be constructed
64        automatically from their coordinates.
65
66        Parameters
67        ----------
68        data : FixedResolutionBuffer or a YTCoveringGrid. Or, an
69            ImageArray, an numpy.ndarray, or dict of such arrays
70            The data to be made into a FITS image or images.
71        fields : single string or list of strings, optional
72            The field names for the data. If *fields* is none and *data* has
73            keys, it will use these for the fields. If *data* is just a
74            single array one field name must be specified.
75        length_unit : string
76            The units of the WCS coordinates and the length unit of the file.
77            Defaults to the length unit of the dataset, if there is one, or
78            "cm" if there is not.
79        width : float or YTQuantity
80            The width of the image. Either a single value or iterable of values.
81            If a float, assumed to be in *units*. Only used if this information
82            is not already provided by *data*.
83        img_ctr : array_like or YTArray
84            The center coordinates of the image. If a list or NumPy array,
85            it is assumed to be in *units*. Only used if this information
86            is not already provided by *data*.
87        wcs : `~astropy.wcs.WCS` instance, optional
88            Supply an AstroPy WCS instance. Will override automatic WCS
89            creation from FixedResolutionBuffers and YTCoveringGrids.
90        current_time : float, tuple, or YTQuantity, optional
91            The current time of the image(s). If not specified, one will
92            be set from the dataset if there is one. If a float, it will
93            be assumed to be in *time_unit* units.
94        time_unit : string
95            The default time units of the file. Defaults to "s".
96        mass_unit : string
97            The default time units of the file. Defaults to "g".
98        velocity_unit : string
99            The default velocity units of the file. Defaults to "cm/s".
100        magnetic_unit : string
101            The default magnetic units of the file. Defaults to "gauss".
102        ds : `~yt.static_output.Dataset` instance, optional
103            The dataset associated with the image(s), typically used
104            to transfer metadata to the header(s). Does not need to be
105            specified if *data* has a dataset as an attribute.
106
107        Examples
108        --------
109
110        >>> # This example uses a FRB.
111        >>> ds = load("sloshing_nomag2_hdf5_plt_cnt_0150")
112        >>> prj = ds.proj(2, "kT", weight_field=("gas", "density"))
113        >>> frb = prj.to_frb((0.5, "Mpc"), 800)
114        >>> # This example just uses the FRB and puts the coords in kpc.
115        >>> f_kpc = FITSImageData(
116        ...     frb, fields="kT", length_unit="kpc", time_unit=(1.0, "Gyr")
117        ... )
118        >>> # This example specifies a specific WCS.
119        >>> from astropy.wcs import WCS
120        >>> w = WCS(naxis=self.dimensionality)
121        >>> w.wcs.crval = [30.0, 45.0]  # RA, Dec in degrees
122        >>> w.wcs.cunit = ["deg"] * 2
123        >>> nx, ny = 800, 800
124        >>> w.wcs.crpix = [0.5 * (nx + 1), 0.5 * (ny + 1)]
125        >>> w.wcs.ctype = ["RA---TAN", "DEC--TAN"]
126        >>> scale = 1.0 / 3600.0  # One arcsec per pixel
127        >>> w.wcs.cdelt = [-scale, scale]
128        >>> f_deg = FITSImageData(frb, fields="kT", wcs=w)
129        >>> f_deg.writeto("temp.fits")
130        """
131
132        if fields is not None:
133            fields = list(iter_fields(fields))
134
135        if ds is None:
136            ds = getattr(data, "ds", None)
137
138        self.fields = []
139        self.field_units = {}
140
141        if unit_header is None:
142            self._set_units(
143                ds, [length_unit, mass_unit, time_unit, velocity_unit, magnetic_unit]
144            )
145        else:
146            self._set_units_from_header(unit_header)
147
148        wcs_unit = str(self.length_unit.units)
149
150        self._fix_current_time(ds, current_time)
151
152        if width is None:
153            width = 1.0
154        if isinstance(width, tuple):
155            if ds is None:
156                width = YTQuantity(width[0], width[1])
157            else:
158                width = ds.quan(width[0], width[1])
159        if img_ctr is None:
160            img_ctr = np.zeros(3)
161
162        exclude_fields = [
163            "x",
164            "y",
165            "z",
166            "px",
167            "py",
168            "pz",
169            "pdx",
170            "pdy",
171            "pdz",
172            "weight_field",
173        ]
174
175        if isinstance(data, _astropy.pyfits.PrimaryHDU):
176            data = _astropy.pyfits.HDUList([data])
177
178        if isinstance(data, _astropy.pyfits.HDUList):
179            self.hdulist = data
180            for hdu in data:
181                self.fields.append(hdu.header["btype"])
182                self.field_units[hdu.header["btype"]] = hdu.header["bunit"]
183
184            self.shape = self.hdulist[0].shape
185            self.dimensionality = len(self.shape)
186            wcs_names = [key for key in self.hdulist[0].header if "WCSNAME" in key]
187            for name in wcs_names:
188                if name == "WCSNAME":
189                    key = " "
190                else:
191                    key = name[-1]
192                w = _astropy.pywcs.WCS(
193                    header=self.hdulist[0].header, key=key, naxis=self.dimensionality
194                )
195                setattr(self, "wcs" + key.strip().lower(), w)
196
197            return
198
199        self.hdulist = _astropy.pyfits.HDUList()
200
201        if hasattr(data, "keys"):
202            img_data = data
203            if fields is None:
204                fields = list(img_data.keys())
205        elif isinstance(data, np.ndarray):
206            if fields is None:
207                mylog.warning(
208                    "No field name given for this array. Calling it 'image_data'."
209                )
210                fn = "image_data"
211                fields = [fn]
212            else:
213                fn = fields[0]
214            img_data = {fn: data}
215
216        for fd in fields:
217            if isinstance(fd, tuple):
218                self.fields.append(fd[1])
219            elif isinstance(fd, DerivedField):
220                self.fields.append(fd.name[1])
221            else:
222                self.fields.append(fd)
223
224        # Sanity checking names
225        s = set()
226        duplicates = {f for f in self.fields if f in s or s.add(f)}
227        if len(duplicates) > 0:
228            for i, fd in enumerate(self.fields):
229                if fd in duplicates:
230                    if isinstance(fields[i], tuple):
231                        ftype, fname = fields[i]
232                    elif isinstance(fields[i], DerivedField):
233                        ftype, fname = fields[i].name
234                    else:
235                        raise RuntimeError(
236                            f"Cannot distinguish between fields with same name {fd}!"
237                        )
238                    self.fields[i] = f"{ftype}_{fname}"
239
240        for is_first, _is_last, (i, (name, field)) in mark_ends(
241            enumerate(zip(self.fields, fields))
242        ):
243            if name not in exclude_fields:
244                this_img = img_data[field]
245                if hasattr(img_data[field], "units"):
246                    has_code_unit = False
247                    for atom in this_img.units.expr.atoms():
248                        if str(atom).startswith("code"):
249                            has_code_unit = True
250                    if has_code_unit:
251                        mylog.warning(
252                            "Cannot generate an image with code "
253                            "units. Converting to units in CGS."
254                        )
255                        funits = this_img.units.get_base_equivalent("cgs")
256                        this_img.convert_to_units(funits)
257                    else:
258                        funits = this_img.units
259                    self.field_units[name] = str(funits)
260                else:
261                    self.field_units[name] = "dimensionless"
262                mylog.info("Making a FITS image of field %s", name)
263                if isinstance(this_img, ImageArray):
264                    if i == 0:
265                        self.shape = this_img.shape[::-1]
266                    this_img = np.asarray(this_img)
267                else:
268                    if i == 0:
269                        self.shape = this_img.shape
270                    this_img = np.asarray(this_img.T)
271                if is_first:
272                    hdu = _astropy.pyfits.PrimaryHDU(this_img)
273                else:
274                    hdu = _astropy.pyfits.ImageHDU(this_img)
275                hdu.name = name
276                hdu.header["btype"] = name
277                hdu.header["bunit"] = re.sub("()", "", self.field_units[name])
278                for unit in ("length", "time", "mass", "velocity", "magnetic"):
279                    if unit == "magnetic":
280                        short_unit = "bf"
281                    else:
282                        short_unit = unit[0]
283                    key = f"{short_unit}unit"
284                    value = getattr(self, f"{unit}_unit")
285                    if value is not None:
286                        hdu.header[key] = float(value.value)
287                        hdu.header.comments[key] = f"[{value.units}]"
288                hdu.header["time"] = float(self.current_time.value)
289                if hasattr(self, "current_redshift"):
290                    hdu.header["HUBBLE"] = self.hubble_constant
291                    hdu.header["REDSHIFT"] = self.current_redshift
292                self.hdulist.append(hdu)
293
294        self.dimensionality = len(self.shape)
295
296        if wcs is None:
297            w = _astropy.pywcs.WCS(
298                header=self.hdulist[0].header, naxis=self.dimensionality
299            )
300            # FRBs and covering grids are special cases where
301            # we have coordinate information, so we take advantage
302            # of this and construct the WCS object
303            if isinstance(img_data, FixedResolutionBuffer):
304                dx = (img_data.bounds[1] - img_data.bounds[0]).to_value(wcs_unit)
305                dy = (img_data.bounds[3] - img_data.bounds[2]).to_value(wcs_unit)
306                dx /= self.shape[0]
307                dy /= self.shape[1]
308                xctr = 0.5 * (img_data.bounds[1] + img_data.bounds[0]).to_value(
309                    wcs_unit
310                )
311                yctr = 0.5 * (img_data.bounds[3] + img_data.bounds[2]).to_value(
312                    wcs_unit
313                )
314                center = [xctr, yctr]
315                cdelt = [dx, dy]
316            elif isinstance(img_data, YTCoveringGrid):
317                cdelt = img_data.dds.to_value(wcs_unit)
318                center = 0.5 * (img_data.left_edge + img_data.right_edge).to_value(
319                    wcs_unit
320                )
321            else:
322                # If img_data is just an array we use the width and img_ctr
323                # parameters to determine the cell widths
324                if not is_sequence(width):
325                    width = [width] * self.dimensionality
326                if isinstance(width[0], YTQuantity):
327                    cdelt = [
328                        wh.to_value(wcs_unit) / n for wh, n in zip(width, self.shape)
329                    ]
330                else:
331                    cdelt = [float(wh) / n for wh, n in zip(width, self.shape)]
332                center = img_ctr[: self.dimensionality]
333            w.wcs.crpix = 0.5 * (np.array(self.shape) + 1)
334            w.wcs.crval = center
335            w.wcs.cdelt = cdelt
336            w.wcs.ctype = ["linear"] * self.dimensionality
337            w.wcs.cunit = [wcs_unit] * self.dimensionality
338            self.set_wcs(w)
339        else:
340            self.set_wcs(wcs)
341
342    def _fix_current_time(self, ds, current_time):
343        if ds is None:
344            registry = None
345        else:
346            registry = ds.unit_registry
347        tunit = Unit(self.time_unit, registry=registry)
348        if current_time is None:
349            if ds is not None:
350                current_time = ds.current_time
351            else:
352                self.current_time = YTQuantity(0.0, "s")
353                return
354        elif isinstance(current_time, numeric_type):
355            current_time = YTQuantity(current_time, tunit)
356        elif isinstance(current_time, tuple):
357            current_time = YTQuantity(current_time[0], current_time[1])
358        self.current_time = current_time.to(tunit)
359
360    def _set_units(self, ds, base_units):
361        if ds is not None:
362            if getattr(ds, "cosmological_simulation", False):
363                self.hubble_constant = ds.hubble_constant
364                self.current_redshift = ds.current_redshift
365        attrs = (
366            "length_unit",
367            "mass_unit",
368            "time_unit",
369            "velocity_unit",
370            "magnetic_unit",
371        )
372        cgs_units = ("cm", "g", "s", "cm/s", "gauss")
373        for unit, attr, cgs_unit in zip(base_units, attrs, cgs_units):
374            if unit is None:
375                if ds is not None:
376                    u = getattr(ds, attr, None)
377                elif attr == "velocity_unit":
378                    u = self.length_unit / self.time_unit
379                elif attr == "magnetic_unit":
380                    u = np.sqrt(
381                        4.0
382                        * np.pi
383                        * self.mass_unit
384                        / (self.time_unit ** 2 * self.length_unit)
385                    )
386                else:
387                    u = cgs_unit
388            else:
389                u = unit
390
391            if isinstance(u, str):
392                uq = YTQuantity(1.0, u)
393            elif isinstance(u, numeric_type):
394                uq = YTQuantity(u, cgs_unit)
395            elif isinstance(u, YTQuantity):
396                uq = u.copy()
397            elif isinstance(u, tuple):
398                uq = YTQuantity(u[0], u[1])
399            else:
400                uq = None
401
402            if uq is not None and hasattr(self, "hubble_constant"):
403                # Don't store cosmology units
404                atoms = {str(a) for a in uq.units.expr.atoms()}
405                if str(uq.units).startswith("cm") or "h" in atoms or "a" in atoms:
406                    uq.convert_to_cgs()
407
408            if uq is not None and uq.units.is_code_unit:
409                mylog.warning(
410                    "Cannot use code units of '%s' "
411                    "when creating a FITSImageData instance! "
412                    "Converting to a cgs equivalent.",
413                    uq.units,
414                )
415                uq.convert_to_cgs()
416
417            if attr == "length_unit" and uq.value != 1.0:
418                mylog.warning("Converting length units from %s to %s.", uq, uq.units)
419                uq = YTQuantity(1.0, uq.units)
420
421            setattr(self, attr, uq)
422
423    def _set_units_from_header(self, header):
424        if "hubble" in header:
425            self.hubble_constant = header["HUBBLE"]
426            self.current_redshift = header["REDSHIFT"]
427        for unit in ["length", "time", "mass", "velocity", "magnetic"]:
428            if unit == "magnetic":
429                key = "BFUNIT"
430            else:
431                key = unit[0].upper() + "UNIT"
432            if key not in header:
433                continue
434            u = header.comments[key].strip("[]")
435            uq = YTQuantity(header[key], u)
436            setattr(self, unit + "_unit", uq)
437
438    def set_wcs(self, wcs, wcsname=None, suffix=None):
439        """
440        Set the WCS coordinate information for all images
441        with a WCS object *wcs*.
442        """
443        if wcsname is None:
444            wcs.wcs.name = "yt"
445        else:
446            wcs.wcs.name = wcsname
447        h = wcs.to_header()
448        if suffix is None:
449            self.wcs = wcs
450        else:
451            setattr(self, "wcs" + suffix, wcs)
452        for img in self.hdulist:
453            for k, v in h.items():
454                kk = k
455                if suffix is not None:
456                    kk += suffix
457                img.header[kk] = v
458
459    def change_image_name(self, old_name, new_name):
460        """
461        Change the name of a FITS image.
462
463        Parameters
464        ----------
465        old_name : string
466            The old name of the image.
467        new_name : string
468            The new name of the image.
469        """
470        idx = self.fields.index(old_name)
471        self.hdulist[idx].name = new_name
472        self.hdulist[idx].header["BTYPE"] = new_name
473        self.field_units[new_name] = self.field_units.pop(old_name)
474        self.fields[idx] = new_name
475
476    def convolve(self, field, kernel, **kwargs):
477        """
478        Convolve an image with a kernel, either a simple
479        Gaussian kernel or one provided by AstroPy. Currently,
480        this only works for 2D images.
481
482        All keyword arguments are passed to
483        :meth:`~astropy.convolution.convolve`.
484
485        Parameters
486        ----------
487        field : string
488            The name of the field to convolve.
489        kernel : float, YTQuantity, (value, unit) tuple, or AstroPy Kernel object
490            The kernel to convolve the image with. If this is an AstroPy Kernel
491            object, the image will be convolved with it. Otherwise, it is
492            assumed that the kernel is a Gaussian and that this value is
493            the standard deviation. If a float, it is assumed that the units
494            are pixels, but a (value, unit) tuple or YTQuantity can be supplied
495            to specify the standard deviation in physical units.
496
497        Examples
498        --------
499        >>> fid = FITSSlice(ds, "z", ("gas", "density"))
500        >>> fid.convolve("density", (3.0, "kpc"))
501        """
502        if self.dimensionality == 3:
503            raise RuntimeError("Convolution currently only works for 2D FITSImageData!")
504        conv = _astropy.conv
505        if field not in self.keys():
506            raise KeyError(f"{field} not an image!")
507        idx = self.fields.index(field)
508        if not isinstance(kernel, conv.Kernel):
509            if not isinstance(kernel, numeric_type):
510                unit = str(self.wcs.wcs.cunit[0])
511                pix_scale = YTQuantity(self.wcs.wcs.cdelt[0], unit)
512                if isinstance(kernel, tuple):
513                    stddev = YTQuantity(kernel[0], kernel[1]).to(unit)
514                else:
515                    stddev = kernel.to(unit)
516                kernel = stddev / pix_scale
517            kernel = conv.Gaussian2DKernel(x_stddev=kernel)
518        self.hdulist[idx].data = conv.convolve(self.hdulist[idx].data, kernel, **kwargs)
519
520    def update_header(self, field, key, value):
521        """
522        Update the FITS header for *field* with a
523        *key*, *value* pair. If *field* == "all", all
524        headers will be updated.
525        """
526        if field == "all":
527            for img in self.hdulist:
528                img.header[key] = value
529        else:
530            if field not in self.keys():
531                raise KeyError(f"{field} not an image!")
532            idx = self.fields.index(field)
533            self.hdulist[idx].header[key] = value
534
535    def update_all_headers(self, key, value):
536        mylog.warning(
537            "update_all_headers is deprecated. "
538            "Use update_header('all', key, value) instead."
539        )
540        self.update_header("all", key, value)
541
542    def keys(self):
543        return self.fields
544
545    def has_key(self, key):
546        return key in self.fields
547
548    def values(self):
549        return [self[k] for k in self.fields]
550
551    def items(self):
552        return [(k, self[k]) for k in self.fields]
553
554    def __getitem__(self, item):
555        return UnitfulHDU(self.hdulist[item])
556
557    def __repr__(self):
558        return str([self[k] for k in self.keys()])
559
560    def info(self, output=None):
561        """
562        Summarize the info of the HDUs in this `FITSImageData`
563        instance.
564
565        Note that this function prints its results to the console---it
566        does not return a value.
567
568        Parameters
569        ----------
570        output : file, boolean, optional
571            A file-like object to write the output to.  If `False`, does not
572            output to a file and instead returns a list of tuples representing
573            the FITSImageData info.  Writes to ``sys.stdout`` by default.
574        """
575        hinfo = self.hdulist.info(output=False)
576        num_cols = len(hinfo[0])
577        if output is None:
578            output = sys.stdout
579        if num_cols == 8:
580            header = "No.    Name      Ver    Type      Cards   Dimensions   Format     Units"
581            format = "{:3d}  {:10}  {:3} {:11}  {:5d}   {}   {}   {}"
582        else:
583            header = (
584                "No.    Name         Type      Cards   Dimensions   Format     Units"
585            )
586            format = "{:3d}  {:10}  {:11}  {:5d}   {}   {}   {}"
587        if self.hdulist._file is None:
588            name = "(No file associated with this FITSImageData)"
589        else:
590            name = self.hdulist._file.name
591        results = [f"Filename: {name}", header]
592        for line in hinfo:
593            units = self.field_units[self.hdulist[line[0]].header["btype"]]
594            summary = tuple(list(line[:-1]) + [units])
595            if output:
596                results.append(format.format(*summary))
597            else:
598                results.append(summary)
599
600        if output:
601            output.write("\n".join(results))
602            output.write("\n")
603            output.flush()
604        else:
605            return results[2:]
606
607    @parallel_root_only
608    def writeto(self, fileobj, fields=None, overwrite=False, **kwargs):
609        r"""
610        Write all of the fields or a subset of them to a FITS file.
611
612        Parameters
613        ----------
614        fileobj : string
615            The name of the file to write to.
616        fields : list of strings, optional
617            The fields to write to the file. If not specified
618            all of the fields in the buffer will be written.
619        overwrite : boolean
620            Whether or not to overwrite a previously existing file.
621            Default: False
622        **kwargs
623            Additional keyword arguments are passed to
624            :meth:`~astropy.io.fits.HDUList.writeto`.
625        """
626        if fields is None:
627            hdus = self.hdulist
628        else:
629            hdus = _astropy.pyfits.HDUList()
630            for field in fields:
631                hdus.append(self.hdulist[field])
632        hdus.writeto(fileobj, overwrite=overwrite, **kwargs)
633
634    def to_glue(self, label="yt", data_collection=None):
635        """
636        Takes the data in the FITSImageData instance and exports it to
637        Glue (http://glueviz.org) for interactive analysis. Optionally
638        add a *label*. If you are already within the Glue environment, you
639        can pass a *data_collection* object, otherwise Glue will be started.
640        """
641        from glue.core import Data, DataCollection
642        from glue.core.coordinates import coordinates_from_header
643
644        try:
645            from glue.app.qt.application import GlueApplication
646        except ImportError:
647            from glue.qt.glue_application import GlueApplication
648
649        image = Data(label=label)
650        image.coords = coordinates_from_header(self.wcs.to_header())
651        for k in self.fields:
652            image.add_component(self[k].data, k)
653        if data_collection is None:
654            dc = DataCollection([image])
655            app = GlueApplication(dc)
656            app.start()
657        else:
658            data_collection.append(image)
659
660    def to_aplpy(self, **kwargs):
661        """
662        Use APLpy (http://aplpy.github.io) for plotting. Returns an
663        `aplpy.FITSFigure` instance. All keyword arguments are passed
664        to the `aplpy.FITSFigure` constructor.
665        """
666        import aplpy
667
668        return aplpy.FITSFigure(self.hdulist, **kwargs)
669
670    def get_data(self, field):
671        """
672        Return the data array of the image corresponding to *field*
673        with units attached. Deprecated.
674        """
675        return self[field].data
676
677    def set_unit(self, field, units):
678        """
679        Set the units of *field* to *units*.
680        """
681        if field not in self.keys():
682            raise KeyError(f"{field} not an image!")
683        idx = self.fields.index(field)
684        new_data = YTArray(self.hdulist[idx].data, self.field_units[field]).to(units)
685        self.hdulist[idx].data = new_data.v
686        self.hdulist[idx].header["bunit"] = units
687        self.field_units[field] = units
688
689    def pop(self, key):
690        """
691        Remove a field with name *key*
692        and return it as a new FITSImageData
693        instance.
694        """
695        if key not in self.keys():
696            raise KeyError(f"{key} not an image!")
697        idx = self.fields.index(key)
698        im = self.hdulist.pop(idx)
699        self.field_units.pop(key)
700        self.fields.remove(key)
701        f = _astropy.pyfits.PrimaryHDU(im.data, header=im.header)
702        return FITSImageData(f, current_time=f.header["TIME"], unit_header=f.header)
703
704    def close(self):
705        self.hdulist.close()
706
707    @classmethod
708    def from_file(cls, filename):
709        """
710        Generate a FITSImageData instance from one previously written to
711        disk.
712
713        Parameters
714        ----------
715        filename : string
716            The name of the file to open.
717        """
718        f = _astropy.pyfits.open(filename, lazy_load_hdus=False)
719        return cls(f, current_time=f[0].header["TIME"], unit_header=f[0].header)
720
721    @classmethod
722    def from_images(cls, image_list):
723        """
724        Generate a new FITSImageData instance from a list of FITSImageData
725        instances.
726
727        Parameters
728        ----------
729        image_list : list of FITSImageData instances
730            The images to be combined.
731        """
732        image_list = image_list if isinstance(image_list, list) else [image_list]
733        first_image = first(image_list)
734
735        w = first_image.wcs
736        img_shape = first_image.shape
737        data = []
738        for is_first, _is_last, fid in mark_ends(image_list):
739            assert_same_wcs(w, fid.wcs)
740            if img_shape != fid.shape:
741                raise RuntimeError("Images do not have the same shape!")
742            for hdu in fid.hdulist:
743                if is_first:
744                    data.append(_astropy.pyfits.PrimaryHDU(hdu.data, header=hdu.header))
745                else:
746                    data.append(_astropy.pyfits.ImageHDU(hdu.data, header=hdu.header))
747        data = _astropy.pyfits.HDUList(data)
748        return cls(
749            data,
750            current_time=first_image.current_time,
751            unit_header=first_image[0].header,
752        )
753
754    def create_sky_wcs(
755        self,
756        sky_center,
757        sky_scale,
758        ctype=None,
759        crota=None,
760        cd=None,
761        pc=None,
762        wcsname="celestial",
763        replace_old_wcs=True,
764    ):
765        """
766        Takes a Cartesian WCS and converts it to one in a
767        sky-based coordinate system.
768
769        Parameters
770        ----------
771        sky_center : iterable of floats
772            Reference coordinates of the WCS in degrees.
773        sky_scale : tuple or YTQuantity
774            Conversion between an angle unit and a length unit,
775            e.g. (3.0, "arcsec/kpc")
776        ctype : list of strings, optional
777            The type of the coordinate system to create. Default:
778            A "tangential" projection.
779        crota : 2-element ndarray, optional
780            Rotation angles between cartesian coordinates and
781            the celestial coordinates.
782        cd : 2x2-element ndarray, optional
783            Dimensioned coordinate transformation matrix.
784        pc : 2x2-element ndarray, optional
785            Coordinate transformation matrix.
786        wcsname : string, optional
787            The name of the WCS to be stored in the FITS header.
788        replace_old_wcs : boolean, optional
789            Whether or not to overwrite the default WCS of the
790            FITSImageData instance. If false, a second WCS will
791            be added to the header. Default: True.
792        """
793        if ctype is None:
794            ctype = ["RA---TAN", "DEC--TAN"]
795        old_wcs = self.wcs
796        naxis = old_wcs.naxis
797        crval = [sky_center[0], sky_center[1]]
798        if isinstance(sky_scale, YTQuantity):
799            scaleq = sky_scale
800        else:
801            scaleq = YTQuantity(sky_scale[0], sky_scale[1])
802        if scaleq.units.dimensions != dimensions.angle / dimensions.length:
803            raise RuntimeError(
804                f"sky_scale {sky_scale} not in correct dimensions of angle/length!"
805            )
806        deltas = old_wcs.wcs.cdelt
807        units = [str(unit) for unit in old_wcs.wcs.cunit]
808        new_dx = (YTQuantity(-deltas[0], units[0]) * scaleq).in_units("deg")
809        new_dy = (YTQuantity(deltas[1], units[1]) * scaleq).in_units("deg")
810        new_wcs = _astropy.pywcs.WCS(naxis=naxis)
811        cdelt = [new_dx.v, new_dy.v]
812        cunit = ["deg"] * 2
813        if naxis == 3:
814            crval.append(old_wcs.wcs.crval[2])
815            cdelt.append(old_wcs.wcs.cdelt[2])
816            ctype.append(old_wcs.wcs.ctype[2])
817            cunit.append(old_wcs.wcs.cunit[2])
818        new_wcs.wcs.crpix = old_wcs.wcs.crpix
819        new_wcs.wcs.cdelt = cdelt
820        new_wcs.wcs.crval = crval
821        new_wcs.wcs.cunit = cunit
822        new_wcs.wcs.ctype = ctype
823        if crota is not None:
824            new_wcs.wcs.crota = crota
825        if cd is not None:
826            new_wcs.wcs.cd = cd
827        if pc is not None:
828            new_wcs.wcs.cd = pc
829        if replace_old_wcs:
830            self.set_wcs(new_wcs, wcsname=wcsname)
831        else:
832            self.set_wcs(new_wcs, wcsname=wcsname, suffix="a")
833
834
class FITSImageBuffer(FITSImageData):
    """Alternative name for :class:`FITSImageData`.

    This subclass adds no attributes or methods of its own; everything is
    inherited from ``FITSImageData``. NOTE(review): presumably retained so
    that code using the older name keeps working -- confirm before removing.
    """
837
838
def sanitize_fits_unit(unit):
    """Normalize a length-unit name before it is written to a FITS header.

    "Mpc" is replaced by "kpc" (logging an info message when it does so),
    "au" is respelled as "AU", and any other value is returned unchanged.
    """
    if unit == "Mpc":
        mylog.info("Changing FITS file length unit to kpc.")
        return "kpc"
    return "AU" if unit == "au" else unit
846
847
# Lookup table: index = the axis (0, 1, 2) that is sliced or projected along,
# value = indices of the two remaining axes. construct_image uses these pairs
# to pick the image's domain widths and its WCS reference center, in this
# order. Kept as lists (not tuples) because they are used for numpy fancy
# indexing of ds.domain_width.
# NOTE(review): the original comment described these as the correct axes of a
# right-handed system; the entry for axis 1 is [0, 2] (x, z) -- confirm this
# ordering matches the intended y-axis image orientation.
axis_wcs = [[1, 2], [0, 2], [0, 1]]
852
853
def construct_image(ds, axis, data_source, center, image_res, width, length_unit):
    """Build the 2D linear WCS (and, when possible, a pixel buffer) for an image.

    Parameters
    ----------
    ds : Dataset
        The dataset the image is made from.
    axis : int or sequence of floats
        An on-axis index (0, 1, 2), or for off-axis images the normal vector.
    data_source : object
        The object to pixelize. If it has a ``to_frb`` method, that is used
        to build the buffer. A ``ParticleAxisAlignedDummyDataSource`` is
        deposited into a ``ParticleImageBuffer``. Anything else (e.g. a dict
        of precomputed images) yields ``frb=None``.
    center : YTArray
        For on-axis images, the 3D center, of which the two image axes
        ``axis_wcs[axis]`` are used. For off-axis images, the 2-element
        center written as the WCS reference value.
    image_res : int or 2-sequence of ints
        Pixel resolution (nx, ny); a scalar is used for both axes.
    width : None or any width accepted by ``sanitize_width``
        Image extent. If None, on-axis images span the full domain along the
        two image axes, and off-axis images use the smallest domain extent
        for both image axes.
    length_unit : str or None
        Unit for the WCS coordinates. If None, it is derived from the width.

    Returns
    -------
    w : WCS
        A 2-axis "LINEAR" WCS whose reference pixel is the image center.
    frb : object or None
        The pixel buffer built from ``data_source``, or None (see above).
    length_unit : str
        The unit actually written to the WCS.
    """
    if width is None:
        if is_sequence(axis):
            # Off-axis: ``axis`` is a normal vector, so there is no
            # ``axis_wcs`` entry to index with. Use the smallest domain
            # extent for both image axes instead of crashing.
            imin = int(np.argmin(ds.domain_width))
            width = ds.domain_width[[imin, imin]]
            mylog.info("Setting the image width to the smallest domain width.")
        else:
            width = ds.domain_width[axis_wcs[axis]]
            # NOTE(review): despite this message, ``center`` is not modified
            # here; the caller-supplied center is used as-is.
            mylog.info(
                "Making an image of the entire domain, "
                "so setting the center to the domain center."
            )
        unit = ds.get_smallest_appropriate_unit(width[0])
    else:
        width = ds.coordinates.sanitize_width(axis, width, None)
        unit = str(width[0].units)
    if is_sequence(image_res):
        nx, ny = image_res
    else:
        nx, ny = image_res, image_res
    dx = width[0] / nx
    dy = width[1] / ny
    # FITS pixels are 1-based, so the geometric image center is (n + 1) / 2.
    crpix = [0.5 * (nx + 1), 0.5 * (ny + 1)]
    # Dataset-relative units are replaced by a concrete physical length unit.
    if unit == "unitary":
        unit = ds.get_smallest_appropriate_unit(ds.domain_width.max())
    elif unit == "code_length":
        unit = ds.get_smallest_appropriate_unit(ds.quan(1.0, "code_length"))
    unit = sanitize_fits_unit(unit)
    if length_unit is None:
        length_unit = unit
    # Strip a numeric scale factor such as "10*kpc" down to its base unit.
    if any(char.isdigit() for char in length_unit) and "*" in length_unit:
        length_unit = length_unit.split("*")[-1]
    cunit = [length_unit] * 2
    ctype = ["LINEAR"] * 2
    cdelt = [dx.in_units(length_unit), dy.in_units(length_unit)]
    if is_sequence(axis):
        crval = center.in_units(length_unit)
    else:
        crval = [center[idx].in_units(length_unit) for idx in axis_wcs[axis]]
    if hasattr(data_source, "to_frb"):
        if is_sequence(axis):
            frb = data_source.to_frb(width[0], (nx, ny), height=width[1])
        else:
            frb = data_source.to_frb(width[0], (nx, ny), center=center, height=width[1])
    elif isinstance(data_source, ParticleAxisAlignedDummyDataSource):
        axes = axis_wcs[axis]
        bounds = (
            center[axes[0]] - width[0] / 2,
            center[axes[0]] + width[0] / 2,
            center[axes[1]] - width[1] / 2,
            center[axes[1]] + width[1] / 2,
        )
        frb = ParticleImageBuffer(
            data_source, bounds, (nx, ny), periodic=all(ds.periodicity)
        )
    else:
        frb = None
    w = _astropy.pywcs.WCS(naxis=2)
    w.wcs.crpix = crpix
    w.wcs.cdelt = cdelt
    w.wcs.crval = crval
    w.wcs.cunit = cunit
    w.wcs.ctype = ctype
    return w, frb, length_unit
913
914
def assert_same_wcs(wcs1, wcs2):
    """Check that two WCS objects describe the same coordinate system.

    Compares the number of axes, the per-axis ``cunit`` and ``ctype``, and
    (to floating-point tolerance) ``crpix``, ``cdelt`` and ``crval``. The
    optional ``crota``, ``cd`` and ``pc`` attributes must either be absent
    on both objects or present and numerically close on both.

    The checks raise explicitly rather than using ``assert`` statements, so
    they still run when Python is started with ``-O``.

    Raises
    ------
    AssertionError
        If any compared attribute differs.
    """
    from numpy.testing import assert_allclose

    if wcs1.naxis != wcs2.naxis:
        raise AssertionError(f"WCS naxis differ: {wcs1.naxis} != {wcs2.naxis}")
    for i in range(wcs1.naxis):
        if not wcs1.wcs.cunit[i] == wcs2.wcs.cunit[i]:
            raise AssertionError(
                f"WCS cunit differ on axis {i}: "
                f"{wcs1.wcs.cunit[i]} != {wcs2.wcs.cunit[i]}"
            )
        if not wcs1.wcs.ctype[i] == wcs2.wcs.ctype[i]:
            raise AssertionError(
                f"WCS ctype differ on axis {i}: "
                f"{wcs1.wcs.ctype[i]} != {wcs2.wcs.ctype[i]}"
            )
    assert_allclose(wcs1.wcs.crpix, wcs2.wcs.crpix)
    assert_allclose(wcs1.wcs.cdelt, wcs2.wcs.cdelt)
    assert_allclose(wcs1.wcs.crval, wcs2.wcs.crval)
    for attr in ("crota", "cd", "pc"):
        _assert_optional_wcs_attr_close(wcs1.wcs, wcs2.wcs, attr)


def _assert_optional_wcs_attr_close(wcsprm1, wcsprm2, attr):
    """Require ``attr`` to be missing on both objects, or present and close on both."""
    from numpy.testing import assert_allclose

    # A missing attribute (AttributeError) is treated as None.
    val1 = getattr(wcsprm1, attr, None)
    val2 = getattr(wcsprm2, attr, None)
    if val1 is None and val2 is None:
        return
    # Compare presence with ``is None``; ``None == ndarray`` would produce an
    # array whose truth value is ambiguous and raise ValueError.
    if val1 is None or val2 is None:
        raise AssertionError(f"WCS attribute {attr!r} is set on only one of the two")
    assert_allclose(val1, val2)
943
944
class FITSSlice(FITSImageData):
    r"""
    Build a FITSImageData from an axis-aligned slice through a dataset.

    Parameters
    ----------
    ds : :class:`~yt.data_objects.static_output.Dataset`
        The dataset to slice.
    axis : character or integer
        The slicing axis: one of "x", "y", "z" or 0, 1, 2.
    fields : string or list of strings
        The field(s) to sample on the slice.
    image_res : an int or 2-tuple of ints
        Pixel resolution of the image. A single integer is used for both
        image axes; a 2-tuple gives each axis its own resolution.
        Default: 512
    center : A sequence of floats, a string, or a tuple.
        Where the image is centered. 'c', 'center', or leaving it unset
        centers on the middle of the domain; 'max' or 'm' centers on the
        maximum of the ('gas', 'density') field. A tuple such as
        ("min", "temperature") or ("max", "dark_matter_density") centers on
        the minimum or maximum of that field. Units may be attached by
        passing a (coordinate, unit string) tuple or a YTArray; a plain list
        or unitless array is taken to be in code units.
    width : tuple or a float.
        The image extent, in one of four forms that allow different x and y
        widths:

        ==================================     =======================
        format                                 example
        ==================================     =======================
        (float, string)                        (10,'kpc')
        ((float, string), (float, string))     ((10,'kpc'),(15,'kpc'))
        float                                  0.2
        (float, float)                         (0.2, 0.3)
        ==================================     =======================

        (10, 'kpc') gives an image 10 kpc wide in both x and y, while
        ((10,'kpc'),(15,'kpc')) gives 10 kpc along x and 15 kpc along y.
        Bare floats are in code units, so (0.2, 0.3) means an x width of
        0.2 and a y width of 0.3 in code units.
    length_unit : string, optional
        Unit in which the image coordinates are written. If not given, a
        unit is chosen from the image width.
    **kwargs
        Extra keyword arguments forwarded to ``ds.slice``.
    """

    def __init__(
        self,
        ds,
        axis,
        fields,
        image_res=512,
        center="c",
        width=None,
        length_unit=None,
        **kwargs,
    ):
        field_list = list(iter_fields(fields))
        axis_index = fix_axis(axis, ds)
        slice_center, image_center = ds.coordinates.sanitize_center(
            center, axis_index
        )
        slice_obj = ds.slice(axis_index, slice_center[axis_index], **kwargs)
        image_wcs, image_buffer, coord_unit = construct_image(
            ds, axis_index, slice_obj, image_center, image_res, width, length_unit
        )
        super().__init__(
            image_buffer, fields=field_list, length_unit=coord_unit, wcs=image_wcs
        )
1014
1015
class FITSProjection(FITSImageData):
    r"""
    Build a FITSImageData from an axis-aligned projection of a dataset.

    Parameters
    ----------
    ds : :class:`~yt.data_objects.static_output.Dataset`
        The dataset to project.
    axis : character or integer
        The line-of-sight axis: one of "x", "y", "z" or 0, 1, 2.
    fields : string or list of strings
        The field(s) to project. Only the first one is handed to
        ``ds.proj`` when the projection object is created.
    image_res : an int or 2-tuple of ints
        Pixel resolution of the image. A single integer is used for both
        image axes; a 2-tuple gives each axis its own resolution.
        Default: 512
    center : A sequence of floats, a string, or a tuple.
        Where the image is centered. 'c', 'center', or leaving it unset
        centers on the middle of the domain; 'max' or 'm' centers on the
        maximum of the ('gas', 'density') field. A tuple such as
        ("min", "temperature") or ("max", "dark_matter_density") centers on
        the minimum or maximum of that field. Units may be attached by
        passing a (coordinate, unit string) tuple or a YTArray; a plain list
        or unitless array is taken to be in code units.
    width : tuple or a float.
        The image extent, in one of four forms that allow different x and y
        widths:

        ==================================     =======================
        format                                 example
        ==================================     =======================
        (float, string)                        (10,'kpc')
        ((float, string), (float, string))     ((10,'kpc'),(15,'kpc'))
        float                                  0.2
        (float, float)                         (0.2, 0.3)
        ==================================     =======================

        (10, 'kpc') gives an image 10 kpc wide in both x and y, while
        ((10,'kpc'),(15,'kpc')) gives 10 kpc along x and 15 kpc along y.
        Bare floats are in code units, so (0.2, 0.3) means an x width of
        0.2 and a y width of 0.3 in code units.
    weight_field : string
        The field used to weight the projection (None for no weighting).
    length_unit : string, optional
        Unit in which the image coordinates are written. If not given, a
        unit is chosen from the image width.
    **kwargs
        Extra keyword arguments forwarded to ``ds.proj``.
    """

    def __init__(
        self,
        ds,
        axis,
        fields,
        image_res=512,
        center="c",
        width=None,
        weight_field=None,
        length_unit=None,
        **kwargs,
    ):
        field_list = list(iter_fields(fields))
        axis_index = fix_axis(axis, ds)
        _, image_center = ds.coordinates.sanitize_center(center, axis_index)
        projection = ds.proj(
            field_list[0], axis_index, weight_field=weight_field, **kwargs
        )
        image_wcs, image_buffer, coord_unit = construct_image(
            ds, axis_index, projection, image_center, image_res, width, length_unit
        )
        super().__init__(
            image_buffer, fields=field_list, length_unit=coord_unit, wcs=image_wcs
        )
1088
1089
class FITSParticleProjection(FITSImageData):
    r"""
    Build a FITSImageData from an axis-aligned projection of a particle
    field.

    Parameters
    ----------
    ds : :class:`~yt.data_objects.static_output.Dataset`
        The dataset to project.
    axis : character or integer
        The line-of-sight axis: one of "x", "y", "z" or 0, 1, 2.
    fields : string or list of strings
        The particle field(s) to project.
    image_res : an int or 2-tuple of ints
        Pixel resolution of the image. A single integer is used for both
        image axes; a 2-tuple gives each axis its own resolution.
        Default: 512
    center : A sequence of floats, a string, or a tuple.
        Where the image is centered. 'c', 'center', or leaving it unset
        centers on the middle of the domain; 'max' or 'm' centers on the
        maximum of the ('gas', 'density') field. A tuple such as
        ("min", "temperature") or ("max", "dark_matter_density") centers on
        the minimum or maximum of that field. Units may be attached by
        passing a (coordinate, unit string) tuple or a YTArray; a plain list
        or unitless array is taken to be in code units.
    width : tuple or a float.
        The image extent, in one of four forms that allow different x and y
        widths:

        ==================================     =======================
        format                                 example
        ==================================     =======================
        (float, string)                        (10,'kpc')
        ((float, string), (float, string))     ((10,'kpc'),(15,'kpc'))
        float                                  0.2
        (float, float)                         (0.2, 0.3)
        ==================================     =======================

        (10, 'kpc') gives an image 10 kpc wide in both x and y, while
        ((10,'kpc'),(15,'kpc')) gives 10 kpc along x and 15 kpc along y.
        Bare floats are in code units, so (0.2, 0.3) means an x width of
        0.2 and a y width of 0.3 in code units.
    depth : A tuple or a float
        The depth to project through, as (value, 'unit'); a bare float is
        taken to be in code units. Default: ``(1, "1")``, described as the
        entire domain.
    weight_field : string
        The field used to weight the projection (None for no weighting).
    length_unit : string, optional
        Unit in which the image coordinates are written. If not given, a
        unit is chosen from the image width.
    deposition : string, optional
        Order of the particle-to-mesh interpolation: "ngp" (the default) is
        the 0th-order nearest-grid-point method, "cic" the 1st-order
        cloud-in-cell method.
    density : boolean, optional
        If True, divide the projected quantity by the cell area to produce
        a projected density of that quantity. Default: False
    field_parameters : dictionary
        Field parameters that derived fields can access. None means an
        empty dictionary.
    data_source : yt.data_objects.data_containers.YTSelectionContainer, optional
        If given, restricts the particles projected to those selected by
        this data object.
    """

    def __init__(
        self,
        ds,
        axis,
        fields,
        image_res=512,
        center="c",
        width=None,
        depth=(1, "1"),
        weight_field=None,
        length_unit=None,
        deposition="ngp",
        density=False,
        field_parameters=None,
        data_source=None,
    ):
        field_list = list(iter_fields(fields))
        axis_index = fix_axis(axis, ds)
        proj_center, image_center = ds.coordinates.sanitize_center(
            center, axis_index
        )
        extent = ds.coordinates.sanitize_width(axis_index, width, depth)
        # Express the last entry (presumably the depth) in the same unit as
        # the first image width.
        extent[-1].convert_to_units(extent[0].units)

        params = {} if field_parameters is None else field_parameters

        particle_source = ParticleAxisAlignedDummyDataSource(
            proj_center,
            ds,
            axis_index,
            extent,
            field_list,
            weight_field,
            field_parameters=params,
            data_source=data_source,
            deposition=deposition,
            density=density,
        )
        image_wcs, image_buffer, coord_unit = construct_image(
            ds, axis_index, particle_source, image_center, image_res, extent, length_unit
        )
        super().__init__(
            image_buffer, fields=field_list, length_unit=coord_unit, wcs=image_wcs
        )
1201
1202
class FITSOffAxisSlice(FITSImageData):
    r"""
    Build a FITSImageData from a slice along an arbitrarily oriented plane.

    Parameters
    ----------
    ds : :class:`~yt.data_objects.static_output.Dataset`
        The dataset to slice.
    normal : a sequence of floats
        The vector normal to the slicing plane.
    fields : string or list of strings
        The field(s) to sample on the slice.
    image_res : an int or 2-tuple of ints
        Pixel resolution of the image. A single integer is used for both
        image axes; a 2-tuple gives each axis its own resolution.
        Default: 512
    center : A sequence of floats, a string, or a tuple.
        Where the image is centered. 'c', 'center', or leaving it unset
        centers on the middle of the domain; 'max' or 'm' centers on the
        maximum of the ('gas', 'density') field. A tuple such as
        ("min", "temperature") or ("max", "dark_matter_density") centers on
        the minimum or maximum of that field. Units may be attached by
        passing a (coordinate, unit string) tuple or a YTArray; a plain list
        or unitless array is taken to be in code units.
    width : tuple or a float.
        The image extent, in one of four forms that allow different x and y
        widths:

        ==================================     =======================
        format                                 example
        ==================================     =======================
        (float, string)                        (10,'kpc')
        ((float, string), (float, string))     ((10,'kpc'),(15,'kpc'))
        float                                  0.2
        (float, float)                         (0.2, 0.3)
        ==================================     =======================

        (10, 'kpc') gives an image 10 kpc wide in both x and y, while
        ((10,'kpc'),(15,'kpc')) gives 10 kpc along x and 15 kpc along y.
        Bare floats are in code units, so (0.2, 0.3) means an x width of
        0.2 and a y width of 0.3 in code units.
    north_vector : a sequence of floats
        The 'up' direction of the image, which fixes the in-plane
        orientation of the slicing plane. If not set, an arbitrary
        grid-aligned north vector is chosen.
    length_unit : string, optional
        Unit in which the image coordinates are written. If not given, a
        unit is chosen from the image width.
    """

    def __init__(
        self,
        ds,
        normal,
        fields,
        image_res=512,
        center="c",
        width=None,
        north_vector=None,
        length_unit=None,
    ):
        field_list = list(iter_fields(fields))
        cut_center, _ = ds.coordinates.sanitize_center(center, 4)
        cutting_plane = ds.cutting(normal, cut_center, north_vector=north_vector)
        # Image-plane coordinates are reported relative to the plane center,
        # so the WCS reference value is (0, 0).
        plane_origin = ds.arr([0.0] * 2, "code_length")
        image_wcs, image_buffer, coord_unit = construct_image(
            ds, normal, cutting_plane, plane_origin, image_res, width, length_unit
        )
        super().__init__(
            image_buffer, fields=field_list, length_unit=coord_unit, wcs=image_wcs
        )
1276
1277
class FITSOffAxisProjection(FITSImageData):
    r"""
    Build a FITSImageData from a projection along an arbitrary direction.

    Parameters
    ----------
    ds : :class:`~yt.data_objects.static_output.Dataset`
        The dataset to project.
    normal : a sequence of floats
        The vector normal to the projection plane (the line of sight).
    fields : string, list of strings
        The field(s) to project; one image is produced per field.
    center : A sequence of floats, a string, or a tuple.
        Where the image is centered. 'c', 'center', or leaving it unset
        centers on the middle of the domain; 'max' or 'm' centers on the
        maximum of the ('gas', 'density') field. A tuple such as
        ("min", "temperature") or ("max", "dark_matter_density") centers on
        the minimum or maximum of that field. Units may be attached by
        passing a (coordinate, unit string) tuple or a YTArray; a plain list
        or unitless array is taken to be in code units.
    width : tuple or a float.
        The image extent, in one of four forms that allow different x and y
        widths. Default: (1.0, "unitary").

        ==================================     =======================
        format                                 example
        ==================================     =======================
        (float, string)                        (10,'kpc')
        ((float, string), (float, string))     ((10,'kpc'),(15,'kpc'))
        float                                  0.2
        (float, float)                         (0.2, 0.3)
        ==================================     =======================

        (10, 'kpc') gives an image 10 kpc wide in both x and y, while
        ((10,'kpc'),(15,'kpc')) gives 10 kpc along x and 15 kpc along y.
        Bare floats are in code units, so (0.2, 0.3) means an x width of
        0.2 and a y width of 0.3 in code units.
    weight_field : string
        The name of the weighting field; None for no weighting.
    image_res : an int or 2-tuple of ints
        Pixel resolution of the image. A single integer is used for both
        image axes; a 2-tuple gives each axis its own resolution.
        Default: 512
    data_source : yt.data_objects.data_containers.YTSelectionContainer, optional
        If given, the projection is made from this data object instead of
        the whole dataset.
    north_vector : a sequence of floats
        The 'up' direction of the image, which fixes the in-plane
        orientation of the projection plane. If not set, an arbitrary
        grid-aligned north vector is chosen.
    depth : A tuple or a float
        The depth to project through, as (value, 'unit'); a bare float is
        taken to be in code units. Default: (1.0, "unitary").
    method : string
        The projection method:

        "integrate" without a weight_field : integrate the field along the
        line of sight.

        "integrate" with a weight_field : weight the field by the weighting
        field and integrate along the line of sight.

        "sum" : like "integrate", but a plain sum along the line of sight
        with no path-length factor. WARNING: only meaningful for
        uniform-resolution grid datasets; others may give unphysical images.
    length_unit : string, optional
        Unit in which the image coordinates are written. If not given, a
        unit is chosen from the image width.
    """

    def __init__(
        self,
        ds,
        normal,
        fields,
        center="c",
        width=(1.0, "unitary"),
        weight_field=None,
        image_res=512,
        data_source=None,
        north_vector=None,
        depth=(1.0, "unitary"),
        method="integrate",
        length_unit=None,
    ):
        field_list = list(iter_fields(fields))
        proj_center, _ = ds.coordinates.sanitize_center(center, 4)
        extent = ds.coordinates.sanitize_width(normal, width, depth)
        extent_code = tuple(el.in_units("code_length").v for el in extent)
        image_res = image_res if is_sequence(image_res) else (image_res, image_res)
        pixel_dims = (image_res[0], image_res[1])
        source = ds if data_source is None else data_source
        # One projected image per field, transposed with swapaxes(0, 1).
        images = {
            field: off_axis_projection(
                source,
                proj_center,
                normal,
                extent_code,
                pixel_dims,
                field,
                north_vector=north_vector,
                method=method,
                weight=weight_field,
            ).swapaxes(0, 1)
            for field in field_list
        }
        # Image-plane coordinates are reported relative to the plane center,
        # so the WCS reference value is (0, 0). The images are already
        # computed, so construct_image's buffer result is not needed.
        plane_origin = ds.arr([0.0] * 2, "code_length")
        image_wcs, _, coord_unit = construct_image(
            ds, normal, images, plane_origin, image_res, extent, length_unit
        )
        super().__init__(
            images, fields=field_list, wcs=image_wcs, length_unit=coord_unit, ds=ds
        )
1400