1# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5
6import os
7import itertools
8from pathlib import Path
9
10import cv2
11import numpy as np
12import PIL.Image
13import torch.nn as nn
14import torch
15import torchvision
16from torchvision.models import resnet50
17import torchvision.transforms as tr
18
19import nevergrad as ng
20import nevergrad.common.typing as tp
21from nevergrad.common import errors
22from .. import base
23from . import imagelosses
24
25# pylint: disable=abstract-method
26
27
class Image(base.ExperimentFunction):
    """Image recovery problem: optimize an image (or PGAN noise vectors) so that it matches a target image.

    The target is the picture ``headrgb_olivier.png`` stored next to this module, resized to 226x226 and
    restricted to its first 3 channels. The value to minimize is an ``imagelosses.ImageLoss`` computed
    against this target.
    """

    def __init__(
        self,
        problem_name: str = "recovering",
        index: int = 0,
        loss: tp.Type[imagelosses.ImageLoss] = imagelosses.SumAbsoluteDifferences,
        with_pgan: bool = False,
        num_images: int = 1,
    ) -> None:
        """
        problem_name: the type of problem we are working on.
           recovering: we directly try to recover the target image.
        index: the index of the problem, inside the problem type.
           For example, if problem_name is "recovering" and index == 0,
           we try to recover the face of O. Teytaud.
        loss: the ImageLoss class (not an instance) to use; it is instantiated with the
           target image as reference.
        with_pgan: if False, the 226x226x3 pixel array is optimized directly (num_images must be 1).
           If True, the (num_images, 512) noise vectors fed to a pretrained PGAN
           (pytorch_GAN_zoo, celebAHQ-512) are optimized instead.
        num_images: number of noise vectors to optimize when with_pgan is True.
        """

        # Storing high level information.
        self.domain_shape = (226, 226, 3)
        self.problem_name = problem_name
        self.index = index
        self.with_pgan = with_pgan
        self.num_images = num_images

        # Storing data necessary for the problem at hand.
        assert problem_name == "recovering"  # For the moment we have only this one.
        assert index == 0  # For the moment only 1 target.
        path = Path(__file__).with_name("headrgb_olivier.png")
        # LANCZOS is the filter formerly aliased as ANTIALIAS (the alias was removed in Pillow 10).
        image = PIL.Image.open(path).resize((self.domain_shape[0], self.domain_shape[1]), PIL.Image.LANCZOS)
        self.data = np.asarray(image)[:, :, :3]  # 4th Channel is pointless here, only 255.
        # The loss is instantiated once and shared: in the pixel case it is also the optimized function.
        self.loss_function = loss(reference=self.data)
        # parametrization
        if not with_pgan:
            assert num_images == 1
            array = ng.p.Array(init=128 * np.ones(self.domain_shape), mutable_sigma=True)
            array.set_mutation(sigma=35)
            array.set_bounds(lower=0, upper=255.99, method="clipping", full_range_sampling=True)
            max_size = ng.p.Scalar(lower=1, upper=200).set_integer_casting()
            array.set_recombination(ng.p.mutation.Crossover(axis=(0, 1), max_size=max_size)).set_name("")  # type: ignore
            super().__init__(self.loss_function, array)
        else:
            self.pgan_model = torch.hub.load(
                "facebookresearch/pytorch_GAN_zoo:hub",
                "PGAN",
                model_name="celebAHQ-512",
                pretrained=True,
                useGPU=False,
            )
            self.domain_shape = (num_images, 512)  # type: ignore
            initial_noise = np.random.normal(size=self.domain_shape)
            # End points of the path of generated images (see _loss_with_pgan).
            self.initial = np.random.normal(size=(1, 512))
            self.target = np.random.normal(size=(1, 512))
            array = ng.p.Array(init=initial_noise, mutable_sigma=True)
            array.set_mutation(sigma=35.0)
            array.set_recombination(ng.p.mutation.Crossover(axis=(0, 1))).set_name("")
            super().__init__(self._loss_with_pgan, array)

        assert self.multiobjective_upper_bounds is None
        # `loss` is a class: loss.__class__.__name__ would be the name of its metaclass (e.g. "type").
        self.add_descriptors(loss=loss.__name__)

    def _generate_images(self, x: np.ndarray) -> np.ndarray:
        """Generates images array of shape [nb_images, x, y, 3] with pixels between 0 and 255.

        The channel axis is reversed ([2, 1, 0]), presumably to obtain the BGR order expected by
        cv2.imwrite in _loss_with_pgan.
        NOTE(review): the same reordered array is also given to self.loss_function, whose reference
        (self.data) is loaded by PIL in RGB order — confirm that this is intended.
        """
        # pylint: disable=not-callable
        noise = torch.tensor(x.astype("float32"))
        return ((self.pgan_model.test(noise).clamp(min=-1, max=1) + 1) * 255.99 / 2).permute(0, 2, 3, 1).cpu().numpy()[:, :, :, [2, 1, 0]]  # type: ignore

    def interpolate(self, base_image: np.ndarray, target: np.ndarray, k: int, num_images: int) -> np.ndarray:
        """Linear interpolation: returns ``target`` for k == 0 and ``base_image`` for k == num_images - 1.

        With num_images == 1, ``target`` is returned whatever k.
        """
        if num_images == 1:
            return target
        coef1 = k / (num_images - 1)
        coef2 = (num_images - 1 - k) / (num_images - 1)
        return coef1 * base_image + coef2 * target

    def _loss_with_pgan(self, x: np.ndarray, export_string: str = "") -> float:
        """Sum of the losses of the images generated along a path of PGAN noise vectors.

        x has shape (num_images, 512). With several images, 10 images per row of x are generated along
        the segment between self.target (first image) and self.initial (last image); each one is moved
        along the direction x[i // 10] rescaled to norm movability * sqrt(self.dimension). With a single
        image, the noise is that rescaled direction only (movability 0.5).
        If export_string is not empty, every generated image is also written to disk as a jpg.
        """
        loss = 0.0
        factor = 1 if self.num_images < 2 else 10  # Number of intermediate images.
        num_total_images = factor * self.num_images
        height, width = self.data.shape[:2]
        for i in range(num_total_images):
            base_i = i // factor
            # i == 0 gives self.target, i == num_total_images - 1 gives self.initial (see interpolate).
            base_image = self.interpolate(self.initial, self.target, i, num_total_images)
            movability = 0.5  # If only one image, then we move by 0.5.
            if self.num_images > 1:
                # 1 in the middle of the path, 0 at both ends (i == 0 and i == num_total_images - 1).
                movability = 4 * (0.25 - (i / (num_total_images - 1) - 0.5) ** 2)
            moving = (
                movability
                * np.sqrt(self.dimension)
                * np.expand_dims(x[base_i], 0)
                / (1e-10 + np.linalg.norm(x[base_i]))
            )
            base_image = moving if self.num_images == 1 else base_image + moving
            image = self._generate_images(base_image).squeeze(0)
            # cv2 dsize is (width, height); match the reference image used by the loss.
            image = cv2.resize(image, dsize=(width, height), interpolation=cv2.INTER_NEAREST)
            if export_string:
                cv2.imwrite(f"{export_string}_image{i}_{num_total_images}_{self.num_images}.jpg", image)
            assert image.shape == self.data.shape, f"{image.shape} != {self.data.shape}"
            loss += self.loss_function(image)
        return loss

    def export_to_images(self, x: np.ndarray, export_string: str = "export"):
        """Writes the images generated for x to jpg files prefixed by export_string (loss is discarded)."""
        self._loss_with_pgan(x, export_string=export_string)
133
134
135# #### Adversarial attacks ##### #
136
137
class Normalize(nn.Module):
    """Per-channel normalization layer computing (x - mean) / std.

    mean and std are stored as plain tensor attributes and cast to the dtype/device of the
    input at every forward call.
    """

    def __init__(self, mean: tp.ArrayLike, std: tp.ArrayLike) -> None:
        super().__init__()
        self.mean = torch.Tensor(mean)
        self.std = torch.Tensor(std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # For 1-D statistics, index to shape (1, C, 1, 1) so they broadcast over the channel axis of x.
        channel_mean = self.mean.type_as(x)[None, :, None, None]
        channel_std = self.std.type_as(x)[None, :, None, None]
        centered = x - channel_mean
        return centered / channel_std
146
147
class Resnet50(nn.Module):
    """Pretrained torchvision ResNet-50 preceded by a Normalize layer with the usual ImageNet statistics.

    Inputs presumably have values in [0, 1] (as produced by ImageAdversarial) — TODO confirm for other callers.
    """

    def __init__(self) -> None:
        super().__init__()
        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]
        self.norm = Normalize(mean=imagenet_mean, std=imagenet_std)
        self.model = resnet50(pretrained=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normalized = self.norm(x)
        return self.model(normalized)
156
157
class TestClassifier(nn.Module):
    """Tiny linear 10-class classifier on flattened images, used as a fast stand-in model.

    Parameters
    ----------
    image_size: int
        side of the square input images; each sample must contain 3 * image_size * image_size values.
    """

    def __init__(self, image_size: int = 224) -> None:
        super().__init__()
        self.model = nn.Linear(image_size * image_size * 3, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # reshape (unlike view) also accepts non-contiguous inputs such as permuted tensors.
        return self.model(x.reshape(x.shape[0], -1))
165
166
167# pylint: disable=too-many-arguments,too-many-instance-attributes
class ImageAdversarial(base.ExperimentFunction):
    """Adversarial attack on an image classifier, cast as a black-box optimization problem.

    The optimized variable is an additive perturbation with the same shape as ``image``, each
    coordinate being clipped to [-epsilon, epsilon]. The perturbed image is clamped to [0, 1]
    before being fed to the classifier.

    Parameters
    ----------
    classifier: nn.Module
        model returning class scores for a batch of shape (1, 3, imsize, imsize).
    image: torch.Tensor
        reference image; it is viewed as (1, 3, imsize, imsize) with imsize = image.shape[1],
        so it is presumably of shape (3, imsize, imsize).
    label: int
        the true class for an untargeted attack, or the class to reach for a targeted one.
    targeted: bool
        if True the cross-entropy w.r.t. label is minimized, otherwise it is maximized.
    epsilon: float
        maximal absolute value of each coordinate of the perturbation.
    """

    def __init__(
        self,
        classifier: nn.Module,
        image: torch.Tensor,
        label: int = 0,
        targeted: bool = False,
        epsilon: float = 0.05,
    ) -> None:
        # TODO add crossover params in args + criterion
        self.targeted = targeted
        self.epsilon = epsilon
        self.image = image
        self.label = torch.tensor([label], dtype=torch.long)
        self.classifier = classifier
        self.criterion = nn.CrossEntropyLoss()
        self.imsize = self.image.shape[1]

        array = ng.p.Array(
            init=np.zeros(self.image.shape),
            mutable_sigma=True,
        ).set_name("")
        array.set_mutation(sigma=self.epsilon / 10)
        array.set_bounds(lower=-self.epsilon, upper=self.epsilon, method="clipping", full_range_sampling=True)
        max_size = ng.p.Scalar(lower=1, upper=200).set_integer_casting()
        array.set_recombination(ng.p.mutation.Crossover(axis=(1, 2), max_size=max_size))  # type: ignore
        super().__init__(self._loss, array)

    def _loss(self, x: np.ndarray) -> float:
        """Cross-entropy of the classifier output w.r.t. the label, negated for untargeted attacks."""
        output_adv = self._get_classifier_output(x)
        value = float(self.criterion(output_adv, self.label).item())
        return value * (1.0 if self.targeted else -1.0)

    def _get_classifier_output(self, x: np.ndarray) -> tp.Any:
        """Returns the classifier scores for the reference image perturbed by x (clamped to [0, 1])."""
        y = torch.Tensor(x)
        image_adv = torch.clamp(self.image + y, 0, 1)
        image_adv = image_adv.view(1, 3, self.imsize, self.imsize)
        # The optimization is black-box: no need to record the autograd graph.
        with torch.no_grad():
            return self.classifier(image_adv)

    def evaluation_function(self, *recommendations: ng.p.Parameter) -> float:
        """Returns whether the attack worked (1.0) or not (0.0)."""
        assert len(recommendations) == 1, "Should not be a pareto set for a singleobjective function"
        x = recommendations[0].value
        output_adv = self._get_classifier_output(x)
        _, pred = torch.max(output_adv, dim=1)
        actual = int(self.label)
        return float(pred == actual if self.targeted else pred != actual)

    @classmethod
    def make_folder_functions(
        cls,
        folder: tp.Optional[tp.PathLike],
        model: str = "resnet50",
    ) -> tp.Generator["ImageAdversarial", None, None]:
        """Yields one untargeted attack per image that the classifier labels correctly.

        Parameters
        ----------
        folder: str or None
            folder to use for reference images (torchvision ImageFolder layout). If None, 1 all-zero
            image is used, labeled with the classifier's own prediction.
        model: str
            model name to use ("resnet50" or "test")

        Yields
        ------
        ExperimentFunction
            an experiment function corresponding to 1 of the image of the provided folder dataset
            (only the first 100 images of the shuffled dataset are considered).
        """
        assert model in {"resnet50", "test"}
        tags = {"folder": "#FAKE#" if folder is None else Path(folder).name, "model": model}
        classifier: tp.Any = Resnet50() if model == "resnet50" else TestClassifier()
        # Inference mode: torchvision's ResNet-50 has batch-normalization layers which, in training
        # mode, would normalize with the statistics of the current (single-image) batch instead of
        # the pretrained running statistics.
        classifier.eval()
        imsize = 224
        transform = tr.Compose([tr.Resize(imsize), tr.CenterCrop(imsize), tr.ToTensor()])
        if folder is None:
            x = torch.zeros(1, 3, imsize, imsize)
            _, pred = torch.max(classifier(x), dim=1)
            data_loader: tp.Iterable[tp.Tuple[tp.Any, tp.Any]] = [(x, pred)]
        elif Path(folder).is_dir():
            ifolder = torchvision.datasets.ImageFolder(folder, transform)
            # DataLoader lives in torch.utils.data (there is no torch.utils.DataLoader).
            data_loader = torch.utils.data.DataLoader(
                ifolder, batch_size=1, shuffle=True, num_workers=8, pin_memory=True
            )
        else:
            raise ValueError(f"{folder} is not a valid folder.")
        for data, target in itertools.islice(data_loader, 0, 100):
            _, pred = torch.max(classifier(data), dim=1)
            if pred == target:  # only attack images which are correctly classified
                func = cls(
                    classifier=classifier, image=data[0], label=int(target), targeted=False, epsilon=0.05
                )
                func.add_descriptors(**tags)
                yield func
265
266
class ImageFromPGAN(base.ExperimentFunction):
    """
    Creates face images using a GAN from pytorch GAN zoo trained on celebAHQ and optimizes the noise vector of the GAN

    Parameters
    ----------
    initial_noise: np.ndarray
        the initial noise of the GAN. It should be of dimension (1, 512). If None, it is defined randomly.
    use_gpu: bool
        whether to use gpus to compute the images (forced to False when CUDA is not available)
    loss: ImageLoss
        which loss instance to use for the images (default: imagelosses.Koncept512())
    mutable_sigma: bool
        whether the sigma should be mutable
    sigma: float
        standard deviation of the initial mutations

    Raises
    ------
    errors.UnsupportedExperiment
        if the CIRCLECI environment variable is set to a non-empty value
    """

    def __init__(
        self,
        initial_noise: tp.Optional[np.ndarray] = None,
        use_gpu: bool = False,
        loss: tp.Optional[imagelosses.ImageLoss] = None,
        mutable_sigma: bool = True,
        sigma: float = 35,
    ) -> None:
        # Check the environment before any costly initialization (default loss, GAN download),
        # so that the dedicated exception is raised even if those would fail on CircleCI.
        if os.environ.get("CIRCLECI", False):
            raise errors.UnsupportedExperiment("ImageFromPGAN is not well supported in CircleCI")
        if loss is None:
            loss = imagelosses.Koncept512()
        if not torch.cuda.is_available():
            use_gpu = False
        self.pgan_model = torch.hub.load(
            "facebookresearch/pytorch_GAN_zoo:hub",
            "PGAN",
            model_name="celebAHQ-512",
            pretrained=True,
            useGPU=use_gpu,
        )

        self.domain_shape = (1, 512)
        if initial_noise is None:
            initial_noise = np.random.normal(size=self.domain_shape)
        assert initial_noise.shape == self.domain_shape, (
            f"The shape of the initial noise vector was {initial_noise.shape}, "
            f"it should be {self.domain_shape}"
        )

        # parametrization
        array = ng.p.Array(init=initial_noise, mutable_sigma=mutable_sigma)
        array.set_mutation(sigma=sigma)
        array.set_recombination(ng.p.mutation.Crossover(axis=(0, 1))).set_name("")

        super().__init__(self._loss, array)
        self.loss_function = loss
        # use_gpu only affects where the computation runs, so it is kept out of the descriptors.
        self._descriptors.pop("use_gpu", None)

        self.add_descriptors(loss=loss.__class__.__name__)

    def _loss(self, x: np.ndarray) -> float:
        """Loss of the image(s) generated from the noise x."""
        image = self._generate_images(x)
        loss = self.loss_function(image)
        return loss

    def _generate_images(self, x: np.ndarray) -> np.ndarray:
        """Generates images array of shape [nb_images, x, y, 3] with pixels between 0 and 255"""
        # pylint: disable=not-callable
        noise = torch.tensor(x.astype("float32"))
        return ((self.pgan_model.test(noise).clamp(min=-1, max=1) + 1) * 255.99 / 2).permute(0, 2, 3, 1).cpu().numpy()  # type: ignore
339