1"""
Description : DataSet module for lip image sequences
3"""
4# Licensed to the Apache Software Foundation (ASF) under one
5# or more contributor license agreements.  See the NOTICE file
6# distributed with this work for additional information
7# regarding copyright ownership.  The ASF licenses this file
8# to you under the Apache License, Version 2.0 (the
9# "License"); you may not use this file except in compliance
10# with the License.  You may obtain a copy of the License at
11#
12#   http://www.apache.org/licenses/LICENSE-2.0
13#
14# Unless required by applicable law or agreed to in writing,
15# software distributed under the License is distributed on an
16# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
17# KIND, either express or implied.  See the License for the
18# specific language governing permissions and limitations
19# under the License.
20
21import os
22import glob
23from mxnet import nd
24import mxnet.gluon.data.dataset as dataset
25from mxnet.gluon.data.vision.datasets import image
26from utils.align import Align
27
28# pylint: disable=too-many-instance-attributes, too-many-arguments
class LipsDataset(dataset.Dataset):
    """
    Description : DataSet class for lip images

    Each item is one sample: the image frames found in a folder
    ``<root>/s<N>/<sample>``, paired with a label built from
    ``<align_root>/<sample>.align``.

    Parameters
    ----------
    root : str
        Directory with one sub-directory per speaker index ``s1`` .. ``s34``
        (``~`` is expanded).
    align_root : str
        Directory containing the ``<sample>.align`` files.
    flag : int, default 1
        Forwarded unchanged to ``image.imread`` (MXNet documents 1 as colour
        and 0 as grayscale).
    mode : {'train', 'valid'}, default 'train'
        ``'valid'`` uses speakers 1, 2, 20 and 22; ``'train'`` uses the
        remaining speakers in 1..34 except 21, which is always skipped.
    transform : callable, optional
        Applied to each frame right after it is read.
    seq_len : int, default 75
        Sample folders whose image count differs from this are ignored; also
        passed as the label padding length.

    Raises
    ------
    ValueError
        If ``mode`` is not ``'train'`` or ``'valid'``.
    """
    def __init__(self, root, align_root, flag=1,
                 mode='train', transform=None, seq_len=75):
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let a bad mode through and fail later
        # with an UnboundLocalError inside _list_images.
        if mode not in ('train', 'valid'):
            raise ValueError(
                "mode must be 'train' or 'valid', got {!r}".format(mode))
        self._root = os.path.expanduser(root)
        self._align_root = align_root
        self._flag = flag
        self._transform = transform
        # Frame file extensions recognised by _list_images (compared
        # case-insensitively).
        self._exts = ['.jpg', '.jpeg', '.png']
        self._seq_len = seq_len
        self._mode = mode
        self._list_images(self._root)

    def _list_images(self, root):
        """
        Description : generate list for lip images

        Populate ``self.items`` with ``(frame_paths, sample_name)`` tuples,
        one per folder ``<root>/<speaker>/<sample>`` holding exactly
        ``seq_len`` image files. ``frame_paths`` is sorted lexicographically;
        ``sample_name`` is the folder's basename.
        """
        # Never populated in this class; kept in case external code reads it.
        self.labels = []
        self.items = []

        valid_unseen_sub_idx = [1, 2, 20, 22]
        skip_sub_idx = [21]

        if self._mode == 'train':
            # Built once rather than re-concatenated on every iteration.
            excluded = set(valid_unseen_sub_idx + skip_sub_idx)
            sub_idx = ['s' + str(i) for i in range(1, 35) if i not in excluded]
        elif self._mode == 'valid':
            sub_idx = ['s' + str(i) for i in valid_unseen_sub_idx]
        else:
            raise ValueError(
                "mode must be 'train' or 'valid', got {!r}".format(self._mode))

        folder_path = []
        for speaker in sub_idx:
            folder_path.extend(glob.glob(os.path.join(root, speaker, "*")))

        for folder in folder_path:
            # Only count actual image frames so that stray non-image entries
            # neither distort the length check nor get handed to imread.
            frames = sorted(
                path for path in glob.glob(os.path.join(folder, "*"))
                if os.path.splitext(path)[1].lower() in self._exts)
            if len(frames) != self._seq_len:
                continue
            self.items.append((frames, os.path.basename(folder)))

    def align_generation(self, file_nm, padding=75):
        """
        Description : Load the alignment label for a sample

        Parameters
        ----------
        file_nm : str
            Sample name; the file read is ``<align_root>/<file_nm>.align``.
        padding : int, default 75
            Forwarded to ``Align.sentence``.

        Returns
        -------
        NDArray
            ``nd.array`` of whatever ``Align.sentence(padding)`` returns --
            presumably the label sequence padded to ``padding``; confirm
            against ``utils.align``.
        """
        align_path = os.path.join(self._align_root, file_nm + '.align')
        return nd.array(Align(align_path).sentence(padding))

    def __getitem__(self, idx):
        """
        Return ``(frames, label)`` for sample ``idx``.

        Frames are read in sorted order with ``image.imread``, passed through
        ``transform`` when one is set, stacked along a new axis 0 and then
        axes 0 and 1 are swapped. NOTE(review): if ``transform`` yields
        (C, H, W) frames the result is (C, T, H, W); without a transform,
        imread's (H, W, C) output becomes (H, T, W, C) -- confirm intent.
        """
        frame_paths, sample_name = self.items[idx]
        img = []
        for image_name in frame_paths:
            tmp_img = image.imread(image_name, self._flag)
            if self._transform is not None:
                tmp_img = self._transform(tmp_img)
            img.append(tmp_img)
        img = nd.stack(*img)
        # Swap the stacked (frame) axis with the first per-frame axis.
        img = nd.transpose(img, (1, 0, 2, 3))
        label = self.align_generation(sample_name, padding=self._seq_len)
        return img, label

    def __len__(self):
        """Return the number of samples that passed the ``seq_len`` filter."""
        return len(self.items)
95