1"""Monodepth
2Digging Into Self-Supervised Monocular Depth Estimation, ICCV 2019
3https://arxiv.org/abs/1806.01260
4"""
5from mxnet.gluon import nn
6from mxnet.context import cpu
7
8from .resnet_encoder import ResnetEncoder
9from .pose_decoder import PoseDecoder
10
11
class MonoDepth2PoseNet(nn.HybridBlock):
    r"""Monodepth2 pose network: a ResNet encoder followed by a pose decoder.

    Calling the network (or ``predict``) returns an ``(axisangle, translation)``
    pair produced by the pose decoder from the encoder features.

    Parameters
    ----------
    backbone : string
        Backbone network type passed to ``ResnetEncoder`` ('resnet18', 'resnet34',
        'resnet50', 'resnet101' or 'resnet152').
    pretrained_base : bool or str
        Whether the backbone is pretrained. If `True`, weights of a model trained
        on ImageNet are loaded into the encoder. If falsy, the encoder is
        initialized here on ``ctx`` instead.
    num_input_images : int
        The number of input images given to the encoder. 1 for a depth encoder,
        larger than 1 for a pose encoder. (Default: 2)
    num_input_features : int
        The number of input feature maps from the posenet encoder. (Default: 1)
    num_frames_to_predict_for : int
        The number of output poses between frames; if None, it equals
        num_input_features - 1. (Default: 2)
    stride : int
        The stride of the convolutions in the pose decoder. (Default: 1)
    ctx : Context
        Context used to build the encoder and to initialize freshly created
        parameters. (Default: cpu())

    Reference:

        Clement Godard, Oisin Mac Aodha, Michael Firman, Gabriel Brostow.
        "Digging Into Self-Supervised Monocular Depth Estimation." ICCV, 2019

    Examples
    --------
    >>> model = MonoDepth2PoseNet(backbone='resnet18', pretrained_base=True)
    >>> print(model)
    """
    # pylint: disable=unused-argument
    def __init__(self, backbone, pretrained_base, num_input_images=2, num_input_features=1,
                 num_frames_to_predict_for=2, stride=1, ctx=cpu(), **kwargs):
        super(MonoDepth2PoseNet, self).__init__()

        with self.name_scope():
            # The encoder is created first because the decoder is sized from
            # its channel counts (``num_ch_enc``).
            self.encoder = ResnetEncoder(backbone, pretrained_base,
                                         num_input_images=num_input_images, ctx=ctx)
            if not pretrained_base:
                # No pretrained backbone weights requested: initialize the encoder here.
                self.encoder.initialize(ctx=ctx)

            decoder_options = dict(num_input_features=num_input_features,
                                   num_frames_to_predict_for=num_frames_to_predict_for,
                                   stride=stride)
            self.decoder = PoseDecoder(self.encoder.num_ch_enc, **decoder_options)
            # The decoder is always freshly initialized on ``ctx``.
            self.decoder.initialize(ctx=ctx)

    def hybrid_forward(self, F, x):
        # pylint: disable=unused-argument
        # ``F`` is not needed: all computation is delegated to the child blocks.
        encoded = [self.encoder(x)]
        axisangle, translation = self.decoder(encoded)
        return axisangle, translation

    def predict(self, x):
        """Estimate the pose via the encoder's and decoder's ``predict`` methods.

        Returns the ``(axisangle, translation)`` pair from the decoder.
        """
        encoded = [self.encoder.predict(x)]
        axisangle, translation = self.decoder.predict(encoded)
        return axisangle, translation

    def demo(self, x):
        """Alias of :meth:`predict`."""
        return self.predict(x)
75
76
def get_monodepth2posenet(backbone='resnet18', pretrained_base=True, num_input_images=2,
                          num_input_features=1, num_frames_to_predict_for=2, stride=1,
                          root='~/.mxnet/models', ctx=cpu(0), pretrained=False,
                          pretrained_model='kitti_stereo_640x192', **kwargs):
    r"""Build a Monodepth2 PoseNet, optionally loading pretrained weights.

    Parameters
    ----------
    backbone : string
        Backbone network type ('resnet18', 'resnet34', 'resnet50',
        'resnet101' or 'resnet152').
    pretrained_base : bool or str
        Whether the backbone is pretrained. If `True`, weights of a model
        trained on ImageNet are loaded into the encoder.
    num_input_images : int
        The number of input images. 1 for a depth encoder, larger than 1 for a
        pose encoder. (Default: 2)
    num_input_features : int
        The number of input feature maps from the posenet encoder. (Default: 1)
    num_frames_to_predict_for : int
        The number of output poses between frames; if None, it equals
        num_input_features - 1. (Default: 2)
    stride : int
        The stride of the convolutions in the pose decoder. (Default: 1)
    root : str, default: '~/.mxnet/models'
        Location for keeping the model parameters.
    ctx : Context, default: CPU
        The context in which to build the model and load the pretrained weights.
    pretrained : bool or str, default: False
        A boolean controls whether to load the default pretrained weights.
        A string selects the hash tag of a specific version of the weights.
    pretrained_model : string, default: kitti_stereo_640x192
        Suffix of the pretrained-weights name; the weights are looked up as
        ``'monodepth2_<backbone>_<pretrained_model>'``.
    **kwargs
        Forwarded to ``MonoDepth2PoseNet``.

    Returns
    -------
    MonoDepth2PoseNet
        The constructed network.
    """
    net = MonoDepth2PoseNet(backbone=backbone, pretrained_base=pretrained_base,
                            num_input_images=num_input_images,
                            num_input_features=num_input_features,
                            num_frames_to_predict_for=num_frames_to_predict_for,
                            stride=stride, ctx=ctx, **kwargs)
    if not pretrained:
        return net

    # Imported only when pretrained weights are actually requested.
    from ...model_zoo.model_store import get_model_file
    weights_name = 'monodepth2_%s_%s' % (backbone, pretrained_model)
    param_file = get_model_file(weights_name, tag=pretrained, root=root)
    net.load_parameters(param_file, ctx=ctx)
    return net
128
129
def get_monodepth2_resnet18_posenet_kitti_mono_640x192(**kwargs):
    r"""Monodepth2 PoseNet with a ResNet-18 backbone ('posenet_kitti_mono_640x192').

    ``backbone`` is fixed to 'resnet18' and ``pretrained_model`` to
    'posenet_kitti_mono_640x192'. Neither may appear in ``kwargs``; passing
    one raises ``TypeError`` (duplicate keyword argument).

    Parameters
    ----------
    **kwargs
        Forwarded to ``get_monodepth2posenet`` (e.g. ``pretrained``,
        ``pretrained_base``, ``ctx``, ``root``).

    Returns
    -------
    MonoDepth2PoseNet
        The network built by ``get_monodepth2posenet``.
    """
    return get_monodepth2posenet(backbone='resnet18',
                                 pretrained_model='posenet_kitti_mono_640x192', **kwargs)
141
142
def get_monodepth2_resnet18_posenet_kitti_mono_stereo_640x192(**kwargs):
    r"""Monodepth2 PoseNet with a ResNet-18 backbone ('posenet_kitti_mono_stereo_640x192').

    ``backbone`` is fixed to 'resnet18' and ``pretrained_model`` to
    'posenet_kitti_mono_stereo_640x192'. Neither may appear in ``kwargs``;
    passing one raises ``TypeError`` (duplicate keyword argument).

    Parameters
    ----------
    **kwargs
        Forwarded to ``get_monodepth2posenet`` (e.g. ``pretrained``,
        ``pretrained_base``, ``ctx``, ``root``).

    Returns
    -------
    MonoDepth2PoseNet
        The network built by ``get_monodepth2posenet``.
    """
    return get_monodepth2posenet(backbone='resnet18',
                                 pretrained_model='posenet_kitti_mono_stereo_640x192', **kwargs)
154