"""
Description : Dataset module for lip-reading image sequences
"""
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import os
import glob
from mxnet import nd
import mxnet.gluon.data.dataset as dataset
from mxnet.gluon.data.vision.datasets import image
from utils.align import Align

# pylint: disable=too-many-instance-attributes, too-many-arguments
class LipsDataset(dataset.Dataset):
    """
    Description : DataSet class for lip images

    Each item is one utterance: the frame images found in
    ``root/<speaker>/<utterance>/``, paired with the alignment file
    ``<align_root>/<utterance>.align``.

    Parameters
    ----------
    root : str
        Directory holding one sub-directory per speaker (``s1`` .. ``s34``).
        Each speaker directory holds one sub-directory of frames per utterance.
    align_root : str
        Directory holding the ``<utterance>.align`` files.
    flag : int
        Passed as the ``flag`` argument of ``image.imread`` for every frame.
    mode : str
        ``'train'`` uses speakers 1-34 except 1, 2, 20, 22 (held out) and
        21 (skipped). ``'valid'`` uses speakers 1, 2, 20 and 22.
    transform : callable or None
        Applied to every frame after it is read, if given.
    seq_len : int
        Exact number of frame images an utterance folder must contain.
        Folders with any other count are skipped. Also used as the padding
        length for the alignment label.

    Raises
    ------
    ValueError
        If ``mode`` is not ``'train'`` or ``'valid'``.
    """
    def __init__(self, root, align_root, flag=1,
                 mode='train', transform=None, seq_len=75):
        # Explicit check rather than ``assert`` so validation survives
        # ``python -O``; otherwise ``_list_images`` would fail obscurely.
        if mode not in ('train', 'valid'):
            raise ValueError(
                "mode must be 'train' or 'valid', got %r" % (mode,))
        self._root = os.path.expanduser(root)
        # Expanded like ``root`` so both paths accept ``~``.
        self._align_root = os.path.expanduser(align_root)
        self._flag = flag
        self._transform = transform
        # Only files with these extensions (case-insensitive) count as frames.
        self._exts = ['.jpg', '.jpeg', '.png']
        self._seq_len = seq_len
        self._mode = mode
        self._list_images(self._root)

    def _list_images(self, root):
        """
        Description : generate list for lip images

        Populates ``self.items`` with ``(frame_paths, utterance_name)`` tuples.
        ``frame_paths`` is the sorted list of frame image paths in one
        utterance folder. ``utterance_name`` is that folder's base name, which
        ``__getitem__`` uses to locate the ``.align`` file.
        """
        # Kept for backward compatibility; this class never populates it.
        self.labels = []
        self.items = []

        valid_unseen_sub_idx = [1, 2, 20, 22]
        skip_sub_idx = [21]

        if self._mode == 'train':
            excluded = set(valid_unseen_sub_idx + skip_sub_idx)
            sub_idx = ['s' + str(i) for i in range(1, 35)
                       if i not in excluded]
        elif self._mode == 'valid':
            sub_idx = ['s' + str(i) for i in valid_unseen_sub_idx]
        else:
            raise ValueError(
                "mode must be 'train' or 'valid', got %r" % (self._mode,))

        folder_path = []
        for sub in sub_idx:
            # glob order is filesystem-dependent; sort so that item indices
            # are reproducible across machines.
            folder_path.extend(sorted(
                path for path in glob.glob(os.path.join(root, sub, "*"))
                if os.path.isdir(path)))

        exts = {ext.lower() for ext in self._exts}
        for folder in folder_path:
            # Ignore stray non-image files (e.g. .DS_Store) so they cannot
            # change the frame count and silently drop the utterance.
            # NOTE(review): lexicographic sort assumes zero-padded frame names.
            filename = sorted(
                path for path in glob.glob(os.path.join(folder, "*"))
                if os.path.splitext(path)[1].lower() in exts)
            if len(filename) != self._seq_len:
                continue
            label = os.path.split(folder)[-1]
            self.items.append((filename, label))

    def align_generation(self, file_nm, padding=75):
        """
        Description : Load the alignment label for one utterance

        Reads ``<align_root>/<file_nm>.align`` with ``Align``.
        Returns ``Align.sentence(padding)`` converted to an NDArray.
        Presumably this is the label sequence padded to ``padding``;
        confirm in ``utils.align``.
        """
        align = Align(os.path.join(self._align_root, file_nm + '.align'))
        return nd.array(align.sentence(padding))

    def __getitem__(self, idx):
        """
        Return ``(frames, label)`` for item ``idx``.

        Each frame is read with ``image.imread`` and passed through
        ``transform`` if one is set. The frames are stacked along a new
        leading axis, and then the first two axes are swapped. Assuming each
        transformed frame is (C, H, W), the result is (C, T, H, W); TODO
        confirm the transform's output layout. ``label`` comes from
        ``align_generation`` padded to ``seq_len``.
        """
        frame_paths, utterance = self.items[idx]
        img = []
        for image_name in frame_paths:
            tmp_img = image.imread(image_name, self._flag)
            if self._transform is not None:
                tmp_img = self._transform(tmp_img)
            img.append(tmp_img)
        img = nd.stack(*img)
        img = nd.transpose(img, (1, 0, 2, 3))
        label = self.align_generation(utterance, padding=self._seq_len)
        return img, label

    def __len__(self):
        """Return the number of utterances that passed the frame-count check."""
        return len(self.items)