"""Monodepth2 pose network.

Digging Into Self-Supervised Monocular Depth Estimation, ICCV 2019
https://arxiv.org/abs/1806.01260
"""
from mxnet.gluon import nn
from mxnet.context import cpu

from .resnet_encoder import ResnetEncoder
from .pose_decoder import PoseDecoder


class MonoDepth2PoseNet(nn.HybridBlock):
    r"""Monodepth2 pose network: a ResNet encoder followed by a pose decoder.

    Parameters
    ----------
    backbone : string
        Backbone network type ('resnet18', 'resnet34', 'resnet50',
        'resnet101' or 'resnet152'); forwarded to ``ResnetEncoder``.
    pretrained_base : bool or str
        Whether the backbone is pretrained. If `True`, the weights of a model
        trained on ImageNet are loaded. If falsy, the encoder is initialized
        with ``initialize(ctx=ctx)`` inside this constructor instead.
    num_input_images : int
        The number of input images fed to the encoder. 1 for the depth encoder,
        larger than 1 for the pose encoder. (Default: 2)
    num_input_features : int
        The number of encoder feature maps passed to the pose decoder. (Default: 1)
    num_frames_to_predict_for : int
        The number of poses predicted between frames; if None, it equals
        num_input_features - 1. (Default: 2)
    stride : int
        The stride of the convolutions in the pose decoder. (Default: 1)
    ctx : Context
        Context passed to the encoder and used to initialize the decoder (and
        the encoder when `pretrained_base` is falsy). (Default: cpu())
    **kwargs
        Accepted for interface compatibility and ignored; they are NOT
        forwarded to ``HybridBlock.__init__``.

    Reference:

        Clement Godard, Oisin Mac Aodha, Michael Firman, Gabriel Brostow.
        "Digging Into Self-Supervised Monocular Depth Estimation." ICCV, 2019

    Examples
    --------
    >>> model = MonoDepth2PoseNet(backbone='resnet18', pretrained_base=True)
    >>> print(model)
    """
    # pylint: disable=unused-argument
    def __init__(self, backbone, pretrained_base, num_input_images=2, num_input_features=1,
                 num_frames_to_predict_for=2, stride=1, ctx=cpu(), **kwargs):
        super(MonoDepth2PoseNet, self).__init__()

        with self.name_scope():
            self.encoder = ResnetEncoder(backbone, pretrained_base,
                                         num_input_images=num_input_images, ctx=ctx)
            # A pretrained encoder already has weights; only a fresh one needs
            # explicit initialization.
            if not pretrained_base:
                self.encoder.initialize(ctx=ctx)
            # The decoder is sized from the encoder's channel counts.
            self.decoder = PoseDecoder(self.encoder.num_ch_enc,
                                       num_input_features=num_input_features,
                                       num_frames_to_predict_for=num_frames_to_predict_for,
                                       stride=stride)
            self.decoder.initialize(ctx=ctx)

    def hybrid_forward(self, F, x):
        """Encode `x` and decode the features into a relative pose.

        Parameters
        ----------
        F : module
            The MXNet symbol/ndarray namespace supplied by ``HybridBlock``
            (unused here).
        x : NDArray or Symbol
            Encoder input. NOTE(review): presumably the `num_input_images`
            frames stacked along the channel axis; confirm against ResnetEncoder.

        Returns
        -------
        tuple
            ``(axisangle, translation)`` exactly as produced by the pose decoder.
        """
        # pylint: disable=unused-argument
        # The decoder consumes a list of encoder outputs; a single encoder
        # pass contributes one entry.
        features = [self.encoder(x)]
        axisangle, translation = self.decoder(features)

        return axisangle, translation

    def demo(self, x):
        """Alias of :meth:`predict`."""
        return self.predict(x)

    def predict(self, x):
        """Inference path built on the sub-blocks' own ``predict`` methods.

        Mirrors :meth:`hybrid_forward`, but calls ``encoder.predict`` and
        ``decoder.predict`` instead of invoking the blocks directly.

        Returns
        -------
        tuple
            ``(axisangle, translation)`` as produced by ``decoder.predict``.
        """
        features = [self.encoder.predict(x)]
        axisangle, translation = self.decoder.predict(features)

        return axisangle, translation


def get_monodepth2posenet(backbone='resnet18', pretrained_base=True, num_input_images=2,
                          num_input_features=1, num_frames_to_predict_for=2, stride=1,
                          root='~/.mxnet/models', ctx=cpu(0), pretrained=False,
                          pretrained_model='kitti_stereo_640x192', **kwargs):
    r"""Build a Monodepth2 pose network, optionally loading pretrained weights.

    Parameters
    ----------
    backbone : string
        Backbone network type ('resnet18', 'resnet34', 'resnet50',
        'resnet101' or 'resnet152').
    pretrained_base : bool or str
        Whether the backbone is pretrained. If `True`, the weights of a model
        trained on ImageNet are loaded.
    num_input_images : int
        The number of input images fed to the encoder. 1 for the depth encoder,
        larger than 1 for the pose encoder. (Default: 2)
    num_input_features : int
        The number of encoder feature maps passed to the pose decoder. (Default: 1)
    num_frames_to_predict_for : int
        The number of poses predicted between frames; if None, it equals
        num_input_features - 1. (Default: 2)
    stride : int
        The stride of the convolutions in the pose decoder. (Default: 1)
    root : str, default: '~/.mxnet/models'
        Location for keeping the model parameters.
    ctx : Context, default: CPU
        The context in which to build the model and load the pretrained weights.
    pretrained : bool or str, default: False
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    pretrained_model : string, default: kitti_stereo_640x192
        The pretrained-weights variant; the weights file requested is
        ``'monodepth2_<backbone>_<pretrained_model>'``.
    **kwargs
        Forwarded to :class:`MonoDepth2PoseNet`, which ignores them.

    Returns
    -------
    MonoDepth2PoseNet
        The constructed (and, if `pretrained`, weight-loaded) network.
    """

    model = MonoDepth2PoseNet(
        backbone=backbone, pretrained_base=pretrained_base,
        num_input_images=num_input_images, num_input_features=num_input_features,
        num_frames_to_predict_for=num_frames_to_predict_for, stride=stride,
        ctx=ctx, **kwargs)

    if pretrained:
        # Imported lazily so the model store is only needed when loading weights.
        from ...model_zoo.model_store import get_model_file
        model.load_parameters(
            get_model_file('monodepth2_%s_%s' % (backbone, pretrained_model),
                           tag=pretrained, root=root),
            ctx=ctx
        )
    return model


def get_monodepth2_resnet18_posenet_kitti_mono_640x192(**kwargs):
    r"""Monodepth2 PoseNet with a ResNet-18 backbone ('posenet_kitti_mono_640x192').

    Parameters
    ----------
    **kwargs
        Forwarded to :func:`get_monodepth2posenet` (e.g. ``pretrained=True``,
        ``ctx``, ``root``). ``backbone`` is fixed to ``'resnet18'`` and
        ``pretrained_model`` to ``'posenet_kitti_mono_640x192'``; passing
        either again raises ``TypeError`` (duplicate keyword argument).

    Returns
    -------
    MonoDepth2PoseNet
    """
    return get_monodepth2posenet(backbone='resnet18',
                                 pretrained_model='posenet_kitti_mono_640x192', **kwargs)


def get_monodepth2_resnet18_posenet_kitti_mono_stereo_640x192(**kwargs):
    r"""Monodepth2 PoseNet with a ResNet-18 backbone ('posenet_kitti_mono_stereo_640x192').

    Parameters
    ----------
    **kwargs
        Forwarded to :func:`get_monodepth2posenet` (e.g. ``pretrained=True``,
        ``ctx``, ``root``). ``backbone`` is fixed to ``'resnet18'`` and
        ``pretrained_model`` to ``'posenet_kitti_mono_stereo_640x192'``;
        passing either again raises ``TypeError`` (duplicate keyword argument).

    Returns
    -------
    MonoDepth2PoseNet
    """
    return get_monodepth2posenet(backbone='resnet18',
                                 pretrained_model='posenet_kitti_mono_stereo_640x192', **kwargs)