# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import typing as tp
import torch
import numpy as np
import imquality.brisque as brisque
import lpips
import cv2
from nevergrad.functions.base import UnsupportedExperiment as UnsupportedExperiment
from nevergrad.common.decorators import Registry


# Registry of the loss classes decorated with @registry.register below.
registry: Registry[tp.Any] = Registry()
# Module-level cache of loaded models, so each one is instantiated at most once per process.
MODELS: tp.Dict[str, tp.Any] = {}


class ImageLoss:
    """Base class for image losses: callables mapping an image array to a float.

    Subclasses implement ``__call__``. ``REQUIRES_REFERENCE`` tells whether the
    loss compares its input against a reference image (subclasses scoring an
    image on its own set it to ``False``).

    Parameters
    ----------
    reference: optional np.ndarray
        Reference image. When provided it must be 3-dimensional, with values in
        [0, 256] and a maximum above 3. ``self.reference`` and
        ``self.domain_shape`` are only set when a reference is provided.
    """

    REQUIRES_REFERENCE = True

    def __init__(self, reference: tp.Optional[np.ndarray] = None) -> None:
        if reference is not None:
            self.reference = reference
            assert len(self.reference.shape) == 3, self.reference.shape
            assert self.reference.min() >= 0.0
            assert self.reference.max() <= 256.0, f"Image max = {self.reference.max()}"
            assert self.reference.max() > 3.0  # Not totally sure but entirely black images are not very cool.
            self.domain_shape = self.reference.shape

    def __call__(self, img: np.ndarray) -> float:
        """Compute the loss of an image (to be implemented by subclasses)."""
        raise NotImplementedError(f"__call__ undefined in class {type(self)}")


@registry.register
class SumAbsoluteDifferences(ImageLoss):
    """Sum of absolute element-wise differences with the reference image."""

    def __call__(self, x: np.ndarray) -> float:
        assert x.shape == self.domain_shape, f"Shape = {x.shape} vs {self.domain_shape}"
        value = float(np.sum(np.fabs(x - self.reference)))
        return value


class Lpips(ImageLoss):
    """LPIPS loss between an image and the reference image.

    Parameters
    ----------
    reference: optional np.ndarray
        Reference image (see ``ImageLoss``).
    net: str
        Name of the network handed to ``lpips.LPIPS`` (the registered
        subclasses below use "alex" and "vgg").
    """

    def __init__(self, reference: tp.Optional[np.ndarray] = None, net: str = "") -> None:
        super().__init__(reference)
        self.net = net

    @staticmethod
    def _to_lpips_tensor(image: np.ndarray) -> torch.Tensor:
        """Convert an (x, y, 3) array into a (1, 3, x, y) tensor rescaled from [0, 256] to [-1, 1]."""
        batch = torch.Tensor(image).unsqueeze(0).permute(0, 3, 1, 2)
        return torch.clamp(batch / 256.0, 0, 1) * 2.0 - 1.0

    def __call__(self, img: np.ndarray) -> float:
        """Return the LPIPS value between ``img`` and the reference image."""
        # The model is loaded lazily and shared through the module-level cache.
        if self.net not in MODELS:
            MODELS[self.net] = lpips.LPIPS(net=self.net)
        loss_fn = MODELS[self.net]
        # Check the rank first so that indexing shape[2] cannot raise an IndexError.
        assert len(img.shape) == 3
        assert img.shape[2] == 3
        assert img.max() <= 256.0, f"Image max = {img.max()}"
        assert img.min() >= 0.0
        assert img.max() > 3.0
        img0 = self._to_lpips_tensor(img)
        # The copy operation is here because of a warning otherwise,
        # as Torch does not support non-writable numpy arrays.
        img1 = self._to_lpips_tensor(self.reference.copy())
        return float(loss_fn(img0, img1))


@registry.register
class LpipsAlex(Lpips):
    """LPIPS loss using the "alex" network."""

    def __init__(self, reference: np.ndarray) -> None:
        super().__init__(reference, net="alex")


@registry.register
class LpipsVgg(Lpips):
    """LPIPS loss using the "vgg" network."""

    def __init__(self, reference: np.ndarray) -> None:
        super().__init__(reference, net="vgg")


@registry.register
class SumSquareDifferences(ImageLoss):
    """Sum of squared element-wise differences with the reference image."""

    def __call__(self, x: np.ndarray) -> float:
        assert x.shape == self.domain_shape, f"Shape = {x.shape} vs {self.domain_shape}"
        value = float(np.sum((x - self.reference) ** 2))
        return value


@registry.register
class HistogramDifference(ImageLoss):
    """Distance between the sorted per-pixel intensities of an image and of the reference.

    Channels are summed for each pixel, the resulting values are sorted for both
    images, and the absolute differences between the sorted sequences are summed.
    The loss is therefore non-negative, and zero when both images have the same
    multiset of per-pixel channel sums.
    """

    def __call__(self, x: np.ndarray) -> float:
        assert x.shape == self.domain_shape, f"Shape = {x.shape} vs {self.domain_shape}"
        assert x.shape[2] == 3
        x_gray_1d = np.sum(x, 2).ravel()
        ref_gray_1d = np.sum(self.reference, 2).ravel()
        # The absolute value is required: a signed sum of sorted differences collapses to
        # sum(x) - sum(reference), which ignores the sorting and is unbounded below.
        value = float(np.sum(np.abs(np.sort(x_gray_1d) - np.sort(ref_gray_1d))))
        return value


@registry.register
class Koncept512(ImageLoss):
    """
    This loss uses the neural network Koncept512 to score images
    It takes one image or a list of images of shape [x, y, 3], with each pixel between 0 and 256, and returns a score.
    The returned loss is the negated output of the model's ``assess`` method.
    """

    REQUIRES_REFERENCE = False

    @property
    def koncept(self) -> tp.Any:  # cache the model
        """Koncept512 model, loaded on first access and stored in the module-level cache.

        Raises
        ------
        UnsupportedExperiment
            when the model is not cached yet and ``os.name`` is "nt" (Windows).
        """
        key = "koncept"
        if key not in MODELS:
            if os.name == "nt":
                raise UnsupportedExperiment("Koncept512 is not working properly under Windows")
            # pylint: disable=import-outside-toplevel
            from koncept.models import Koncept512 as K512Model

            MODELS[key] = K512Model()
        return MODELS[key]

    def __call__(self, img: np.ndarray) -> float:
        loss = -self.koncept.assess(img)
        return float(loss)


@registry.register
class Blur(ImageLoss):
    """
    This estimates blurriness.
    The loss is the negated variance of the Laplacian of the image.
    """

    REQUIRES_REFERENCE = False

    def __call__(self, img: np.ndarray) -> float:
        # Check the rank first so that indexing shape[2] cannot raise an IndexError.
        assert len(img.shape) == 3
        assert img.shape[2] == 3
        img = np.asarray(img, dtype=np.float64)
        return -float(cv2.Laplacian(img, cv2.CV_64F).var())


@registry.register
class Brisque(ImageLoss):
    """
    This estimates the Brisque score (lower is better).
    """

    REQUIRES_REFERENCE = False

    def __call__(self, img: np.ndarray) -> float:
        try:
            score = brisque.score(img)
        except AssertionError:  # oh my god, brisque can raise an assert when the data is too weird.
            score = float("inf")
        # Coerce to a builtin float, as all other losses do.
        return float(score)