1# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5
6import os
7import typing as tp
8import torch
9import numpy as np
10import imquality.brisque as brisque
11import lpips
12import cv2
13from nevergrad.functions.base import UnsupportedExperiment as UnsupportedExperiment
14from nevergrad.common.decorators import Registry
15
16
# Registry of the image loss classes of this module; classes are added to it
# through the ``@registry.register`` class decorator.
registry: Registry[tp.Any] = Registry()
# Module-level cache of loaded models, shared by all loss instances: a model is
# only instantiated when its key is not present yet (keys used in this module are
# the lpips network name and "koncept").
MODELS: tp.Dict[str, tp.Any] = {}
19
20
class ImageLoss:
    """Base class of image losses: callables mapping an image array to a float.

    Subclasses must implement ``__call__``. The class attribute
    ``REQUIRES_REFERENCE`` is True here and is overridden to False by the
    subclasses which score an image without comparing it to a reference.

    Parameters
    ----------
    reference: np.ndarray or None
        optional reference image. When provided, it must be a 3-dimensional array
        with values within [0, 256] and a maximum strictly above 3. It is then
        stored as ``self.reference`` and its shape as ``self.domain_shape``.
        When None, neither attribute is set.
    """

    REQUIRES_REFERENCE = True

    def __init__(self, reference: tp.Optional[np.ndarray] = None) -> None:
        if reference is None:
            return  # nothing to record, no attribute is created
        self.reference = reference
        shape = reference.shape
        assert len(shape) == 3, shape
        assert reference.min() >= 0.0
        peak = reference.max()
        assert peak <= 256.0, f"Image max = {peak}"
        assert peak > 3.0  # Not totally sure but entirely black images are not very cool.
        self.domain_shape = shape

    def __call__(self, img: np.ndarray) -> float:
        """Compute the loss of an image (to be implemented by subclasses)."""
        raise NotImplementedError(f"__call__ undefined in class {type(self)}")
36
37
@registry.register
class SumAbsoluteDifferences(ImageLoss):
    """Sum of the absolute element-wise differences between an image and the reference."""

    def __call__(self, x: np.ndarray) -> float:
        """Return the sum over all elements of ``|x - reference|``.

        Parameters
        ----------
        x: np.ndarray
            candidate image, which must have the same shape as the reference.

        Returns
        -------
        float
            the summed absolute difference (0 when both images are equal).
        """
        assert x.shape == self.domain_shape, f"Shape = {x.shape} vs {self.domain_shape}"
        # Promote both operands to float64 before subtracting: with unsigned integer
        # images (e.g. uint8) the raw difference wraps around instead of being negative.
        difference = np.asarray(x, dtype=np.float64) - np.asarray(self.reference, dtype=np.float64)
        return float(np.sum(np.abs(difference)))
44
45
class Lpips(ImageLoss):
    """LPIPS loss between an image and the reference image.

    Parameters
    ----------
    reference: np.ndarray or None
        reference image, forwarded to ``ImageLoss``.
    net: str
        network name forwarded to ``lpips.LPIPS(net=...)``; the subclasses in this
        module use "alex" and "vgg".
    """

    def __init__(self, reference: tp.Optional[np.ndarray] = None, net: str = "") -> None:
        super().__init__(reference)
        self.net = net

    @staticmethod
    def _to_tensor(img: np.ndarray) -> torch.Tensor:
        """Convert a channels-last (a, b, 3) image into a (1, 3, a, b) tensor rescaled to [-1, 1]."""
        batch = torch.Tensor(img).unsqueeze(0).permute(0, 3, 1, 2)
        # values are divided by 256, clamped to [0, 1], then mapped linearly to [-1, 1]
        return torch.clamp(batch / 256.0, 0, 1) * 2.0 - 1.0

    def __call__(self, img: np.ndarray) -> float:
        """Return the LPIPS value computed between ``img`` and the reference.

        Parameters
        ----------
        img: np.ndarray
            3-dimensional image with 3 channels on the last axis, values within
            [0, 256] and a maximum strictly above 3.

        Returns
        -------
        float
            output of the cached LPIPS model for the (image, reference) pair.
        """
        # the model is cached at module level so that it is only instantiated once per network name
        if self.net not in MODELS:
            MODELS[self.net] = lpips.LPIPS(net=self.net)
        loss_fn = MODELS[self.net]
        # check the rank before indexing the shape, so that an input with fewer than
        # 3 dimensions fails on the assertion rather than with an IndexError
        assert len(img.shape) == 3
        assert img.shape[2] == 3
        assert img.max() <= 256.0, f"Image max = {img.max()}"
        assert img.min() >= 0.0
        assert img.max() > 3.0
        img0 = self._to_tensor(img)
        # The copy operation is here because of a warning otherwise, as Torch does not support non-writable numpy arrays.
        img1 = self._to_tensor(self.reference.copy())
        # only the value is needed: skip building the autograd graph
        with torch.no_grad():
            return float(loss_fn(img0, img1))
67
68
@registry.register
class LpipsAlex(Lpips):
    """Lpips loss configured with the "alex" network."""

    def __init__(self, reference: np.ndarray) -> None:
        super().__init__(reference=reference, net="alex")
73
74
@registry.register
class LpipsVgg(Lpips):
    """Lpips loss configured with the "vgg" network."""

    def __init__(self, reference: np.ndarray) -> None:
        super().__init__(reference=reference, net="vgg")
79
80
@registry.register
class SumSquareDifferences(ImageLoss):
    """Sum of the squared element-wise differences between an image and the reference."""

    def __call__(self, x: np.ndarray) -> float:
        """Return the sum over all elements of ``(x - reference) ** 2``.

        Parameters
        ----------
        x: np.ndarray
            candidate image, which must have the same shape as the reference.

        Returns
        -------
        float
            the summed squared difference (0 when both images are equal).
        """
        assert x.shape == self.domain_shape, f"Shape = {x.shape} vs {self.domain_shape}"
        # Promote both operands to float64 first: with integer images (e.g. uint8) the
        # subtraction wraps around and the square overflows the narrow dtype.
        difference = np.asarray(x, dtype=np.float64) - np.asarray(self.reference, dtype=np.float64)
        return float(np.sum(difference ** 2))
87
88
@registry.register
class HistogramDifference(ImageLoss):
    """Distance between the sorted per-pixel intensities of an image and of the reference.

    The intensity of a pixel is the sum of its 3 channels. Both intensity arrays are
    sorted, which discards pixel positions, and compared element by element.
    """

    def __call__(self, x: np.ndarray) -> float:
        """Return the sum of absolute differences between the sorted intensities.

        Parameters
        ----------
        x: np.ndarray
            candidate image, with the same shape as the reference and 3 channels
            on the last axis.

        Returns
        -------
        float
            a non-negative value, which is 0 when both images hold the same
            multiset of pixel intensities.
        """
        assert x.shape == self.domain_shape, f"Shape = {x.shape} vs {self.domain_shape}"
        assert x.shape[2] == 3
        # accumulate in float64 so that the later subtraction cannot wrap around
        # with unsigned integer images
        x_gray_1d = np.sum(x, axis=2, dtype=np.float64).ravel()
        ref_gray_1d = np.sum(self.reference, axis=2, dtype=np.float64).ravel()
        # The absolute value is required: signed differences cancel each other, which
        # would make the sort irrelevant and reduce the loss to
        # sum(x) - sum(reference) (possibly negative).
        gaps = np.abs(np.sort(x_gray_1d) - np.sort(ref_gray_1d))
        return float(np.sum(gaps))
98
99
@registry.register
class Koncept512(ImageLoss):
    """
    This loss scores images with the neural network Koncept512.
    It takes one image or a list of images of shape [x, y, 3], with each pixel between 0 and 256,
    and returns the negated assessment of the network as a loss.
    """

    REQUIRES_REFERENCE = False

    @property
    def koncept(self) -> tp.Any:
        """Koncept512 model, instantiated on first access and then cached at module level.

        Raises
        ------
        UnsupportedExperiment
            if the model is not cached yet and the platform is Windows.
        """
        name = "koncept"
        if name in MODELS:
            return MODELS[name]
        if os.name == "nt":
            raise UnsupportedExperiment("Koncept512 is not working properly under Windows")
        # pylint: disable=import-outside-toplevel
        from koncept.models import Koncept512 as K512Model

        MODELS[name] = K512Model()
        return MODELS[name]

    def __call__(self, img: np.ndarray) -> float:
        """Return the negated output of the model assessment as a float."""
        return float(-self.koncept.assess(img))
125
126
@registry.register
class Blur(ImageLoss):
    """
    This estimates blurriness.
    The loss is the negated variance of the Laplacian of the image, so that
    a larger variance gives a lower loss.
    """

    REQUIRES_REFERENCE = False

    def __call__(self, img: np.ndarray) -> float:
        """Return minus the variance of the Laplacian of the image.

        Parameters
        ----------
        img: np.ndarray
            3-dimensional image with 3 channels on the last axis.

        Returns
        -------
        float
            the negated variance of ``cv2.Laplacian(img, cv2.CV_64F)``.
        """
        # check the rank before indexing the shape, so that an input with fewer than
        # 3 dimensions fails on the assertion rather than with an IndexError
        assert len(img.shape) == 3
        assert img.shape[2] == 3
        img = np.asarray(img, dtype=np.float64)
        return -float(cv2.Laplacian(img, cv2.CV_64F).var())
140
141
@registry.register
class Brisque(ImageLoss):
    """
    Loss given by the Brisque score of the image (lower is better).
    """

    REQUIRES_REFERENCE = False

    def __call__(self, img: np.ndarray) -> float:
        """Return the Brisque score of the image, or infinity when the scoring asserts."""
        try:
            return brisque.score(img)
        except AssertionError:
            # oh my god, brisque can raise an assert when the data is too weird.
            return float("inf")
156