/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */


package org.apache.mxnetexamples.rnn

import org.apache.mxnet.DType.DType
import org.apache.mxnet._
import org.slf4j.LoggerFactory

import scala.collection.immutable.ListMap
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
import scala.util.Random
import scala.collection.mutable

/**
 * Utilities for turning a plain-text corpus into length-bucketed batches of
 * word ids for RNN training.
 */
object BucketIo {

  /** Converts a sentence into word ids using the given vocabulary. */
  type Text2Id = (String, Map[String, Int]) => Array[Int]
  /** Loads the raw text content located at the given path. */
  type ReadContent = String => String

  /**
   * Reads a UTF-8 text file and replaces every ". " and every newline with " <eos> ".
   *
   * @param path file to read
   * @return the file content with sentence/line ends turned into an explicit <eos> token
   */
  def defaultReadContent(path: String): String = {
    val source = Source.fromFile(path, "UTF-8")
    // release the file handle even if reading fails
    try source.mkString.replaceAll("\\. |\n", " <eos> ")
    finally source.close()
  }

  /**
   * Builds a vocabulary from the file at `path`, read with [[defaultReadContent]].
   *
   * Id 0 is reserved for zero-padding and is held by the dummy key " ". Every other
   * distinct non-empty space-separated token gets the next id in order of first
   * appearance, starting at 1.
   *
   * @param path corpus file
   * @return immutable word-to-id map
   */
  def defaultBuildVocab(path: String): Map[String, Int] = {
    val content = defaultReadContent(path).split(" ")
    val vocab = mutable.Map.empty[String, Int]
    // dummy element so that vocab.size == largest id + 1
    vocab.put(" ", 0)
    var idx = 1 // 0 is left for zero-padding
    content.foreach { word =>
      if (word.nonEmpty && !vocab.contains(word)) {
        vocab.put(word, idx)
        idx += 1
      }
    }
    vocab.toMap
  }

  /**
   * Splits a sentence on single spaces, drops empty tokens and maps each word to its id.
   *
   * @throws NoSuchElementException if a word is missing from `theVocab`
   */
  def defaultText2Id(sentence: String, theVocab: Map[String, Int]): Array[Int] =
    sentence.split(" ").filter(_.length > 0).map(w => theVocab(w))

  /**
   * Chooses bucket lengths so that each bucket collects at least `batchSize` sentences.
   *
   * Non-empty sentence lengths (in words) are visited in ascending order. Their counts are
   * accumulated until they reach `batchSize`, at which point the current length becomes a
   * bucket. Any remaining sentences are covered by a final bucket of the maximum length,
   * so every non-empty sentence fits into some bucket.
   *
   * @return bucket lengths in ascending order (empty when no sentence has any word)
   */
  def defaultGenBuckets(sentences: Array[String], batchSize: Int,
                        theVocab: Map[String, Int]): IndexedSeq[Int] = {
    // number of sentences for each non-zero sentence length
    val lenCounts = sentences.iterator
      .map(sentence => defaultText2Id(sentence, theVocab).length)
      .filter(_ > 0)
      .foldLeft(Map.empty[Int, Int]) { (acc, len) => acc.updated(len, acc.getOrElse(len, 0) + 1) }

    val buckets = ArrayBuffer[Int]()
    var tl = 0
    // Lengths must be walked in ascending order: a bucket of length l only holds
    // sentences of length <= l, and the last visited length must be the maximum.
    lenCounts.toSeq.sortBy(_._1).foreach { case (l, n) =>
      if (n + tl >= batchSize) {
        buckets.append(l)
        tl = 0
      } else {
        tl += n
      }
    }
    if (tl > 0) buckets.append(lenCounts.keys.max)
    buckets.toIndexedSeq
  }

  /**
   * A [[DataIter]] that groups sentences into buckets by length. It yields batches in which
   * every sentence is right-padded with 0 to its bucket's length.
   *
   * The label of a sentence is its id sequence shifted left by one word with a trailing 0.
   * Each batch also carries one freshly allocated, zero-filled array per entry of
   * `initStates`. The order of batches is shuffled once at construction time.
   *
   * @param path corpus location passed to `readContent`
   * @param vocab word-to-id map passed to `text2Id`
   * @param buckets bucket lengths; generated with [[defaultGenBuckets]] when empty
   * @param _batchSize number of sentences per batch
   * @param initStates (name, (dim0, dim1)) of each extra zero-initialised input
   * @param seperateChar sentence separator; note that `String.split` treats it as a regex
   * @param text2Id sentence-to-ids conversion
   * @param readContent loader for the raw corpus text
   */
  class BucketSentenceIter(
      path: String,
      vocab: Map[String, Int],
      var buckets: IndexedSeq[Int],
      _batchSize: Int,
      private val initStates: IndexedSeq[(String, (Int, Int))],
      seperateChar: String = " <eos> ",
      text2Id: Text2Id = defaultText2Id,
      readContent: ReadContent = defaultReadContent) extends DataIter {
    private val logger = LoggerFactory.getLogger(classOf[BucketSentenceIter])

    private val content = readContent(path)
    private val sentences = content.split(seperateChar)

    if (buckets.isEmpty) {
      buckets = defaultGenBuckets(sentences, _batchSize, vocab)
    }
    buckets = buckets.sorted
    require(buckets.nonEmpty, s"no buckets could be built from the sentences in $path")

    // pre-allocate with the largest bucket for better memory sharing
    private val _defaultBucketKey = buckets.max
    override def defaultBucketKey: AnyRef = Int.box(_defaultBucketKey)

    // Put each non-empty sentence into the smallest bucket that can hold it, zero-padded to
    // that bucket's length. Sentences longer than the largest bucket are ignored.
    private val data: Array[Array[Array[Float]]] = {
      val perBucket = Array.fill(buckets.length)(ArrayBuffer[Array[Float]]())
      for (sentence <- sentences) {
        val ids = text2Id(sentence, vocab)
        if (ids.length > 0) {
          val idx = buckets.indexWhere(_ >= ids.length)
          if (idx >= 0) {
            perBucket(idx) += (ids.map(_.toFloat) ++ Array.fill[Float](buckets(idx) - ids.length)(0f))
          }
        }
      }
      perBucket.map(_.toArray)
    }

    // Get the size of each bucket, so that we could sample uniformly from the bucket
    private val bucketSizes = data.map(_.length)
    logger.info("Summary of dataset ==================")
    buckets.zip(bucketSizes).foreach {
      case (bkt, size) => logger.info(s"bucket of len $bkt : $size samples")
    }

    // make a random data iteration plan:
    // truncate each bucket to a multiple of the batch size
    private val bucketNBatches: Array[Int] = data.map(_.length / _batchSize)
    for (i <- data.indices) {
      data(i) = data(i).take(bucketNBatches(i) * _batchSize)
    }

    // one entry per batch, holding the bucket index that batch is drawn from
    private val bucketPlan = {
      val plan = bucketNBatches.indices.flatMap(i => Seq.fill(bucketNBatches(i))(i))
      Random.shuffle(plan.toList).toArray
    }

    // shuffled sample order inside each bucket, and the read position within it
    private val bucketIdxAll = data.map(_.length).map(l =>
      Random.shuffle((0 until l).toList).toArray)
    private val bucketCurrIdx = data.map(_ => 0)

    // reusable per-bucket buffers of shape (batchSize, bucketLength)
    private val dataBuffer = ArrayBuffer[NDArray]()
    private val labelBuffer = ArrayBuffer[NDArray]()
    for (b <- data.indices) {
      dataBuffer.append(NDArray.zeros(_batchSize, buckets(b)))
      labelBuffer.append(NDArray.zeros(_batchSize, buckets(b)))
    }

    private val _provideData = {
      val tmp = ListMap("data" -> Shape(_batchSize, _defaultBucketKey))
      tmp ++ initStates.map(x => x._1 -> Shape(x._2._1, x._2._2))
    }

    private val _provideLabel = ListMap("softmax_label" -> Shape(_batchSize, _defaultBucketKey))

    private val _provideDataDesc = {
      // TODO: need to allow user to specify DType and Layout
      val tmp = IndexedSeq(new DataDesc("data",
        Shape(_batchSize, _defaultBucketKey), DType.Float32, Layout.UNDEFINED))
      tmp ++ initStates.map(x => new DataDesc(x._1, Shape(x._2._1, x._2._2),
        DType.Float32, Layout.UNDEFINED))
    }

    private val _provideLabelDesc = IndexedSeq(
      // TODO: need to allow user to specify DType and Layout
      new DataDesc("softmax_label",
        Shape(_batchSize, _defaultBucketKey), DType.Float32, Layout.UNDEFINED))

    // position in bucketPlan of the next batch to produce
    private var iBucket = 0

    /**
     * Produces the next batch according to the shuffled bucket plan.
     *
     * @throws NoSuchElementException when the plan is exhausted
     */
    override def next(): DataBatch = {
      if (!hasNext) throw new NoSuchElementException
      val bucketIdx = bucketPlan(iBucket)
      val bucketLen = buckets(bucketIdx)
      val start = bucketCurrIdx(bucketIdx)
      val idx = bucketIdxAll(bucketIdx).slice(start, start + _batchSize)
      bucketCurrIdx(bucketIdx) = start + _batchSize

      val dataBuf = dataBuffer(bucketIdx)
      val datas = idx.map(i => data(bucketIdx)(i))
      datas.foreach(sentence => require(sentence.length == bucketLen))
      dataBuf.set(datas.flatten)

      // label = input shifted left by one word, padded with a trailing 0
      val labelBuf = labelBuffer(bucketIdx)
      val labels = idx.map(i => data(bucketIdx)(i).drop(1) :+ 0f)
      labelBuf.set(labels.flatten)

      iBucket += 1
      val batchProvideData = IndexedSeq(DataDesc("data", dataBuf.shape, dataBuf.dtype)) ++
        initStates.map {
          case (name, shape) => DataDesc(name, Shape(shape._1, shape._2), DType.Float32)
        }
      val batchProvideLabel = IndexedSeq(DataDesc("softmax_label", labelBuf.shape,
        labelBuf.dtype))
      val initStateArrays = initStates.map(x => NDArray.zeros(x._2._1, x._2._2))
      new DataBatch(IndexedSeq(dataBuf.copy()) ++ initStateArrays,
        IndexedSeq(labelBuf.copy()),
        getIndex(),
        getPad(),
        Int.box(bucketLen),
        batchProvideData, batchProvideLabel)
    }

    /**
     * reset the iterator (the bucket plan and the in-bucket sample order are kept)
     */
    override def reset(): Unit = {
      iBucket = 0
      bucketCurrIdx.indices.foreach(i => bucketCurrIdx(i) = 0)
    }

    override def batchSize: Int = _batchSize

    // Bucket of the batch most recently returned by next(). Before the first next()
    // (or right after reset()) this is the first planned bucket.
    private def currentBucketIdx: Int = bucketPlan(math.max(iBucket - 1, 0))

    /**
     * get data of current batch
     * @return the data buffer of the batch most recently returned by next()
     */
    override def getData(): IndexedSeq[NDArray] = IndexedSeq(dataBuffer(currentBucketIdx))

    /**
     * Get label of current batch
     * @return the label buffer of the batch most recently returned by next()
     */
    override def getLabel(): IndexedSeq[NDArray] = IndexedSeq(labelBuffer(currentBucketIdx))

    /**
     * the index of current batch
     * @return always empty; this iterator does not track sample indices
     */
    override def getIndex(): IndexedSeq[Long] = IndexedSeq[Long]()

    /**
     * get the number of padding examples
     * in current batch
     * @return always 0, since buckets are truncated to whole batches
     */
    override def getPad(): Int = 0

    // The name and shape of label provided by this iterator
    @deprecated("Use provideLabelDesc instead", "1.3.0")
    override def provideLabel: ListMap[String, Shape] = this._provideLabel

    // The name and shape of data provided by this iterator
    @deprecated("Use provideDataDesc instead", "1.3.0")
    override def provideData: ListMap[String, Shape] = this._provideData

    // Provide type:DataDesc of the data
    override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc

    // Provide type:DataDesc of the label
    override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc

    override def hasNext: Boolean = iBucket < bucketPlan.length
  }
}