/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */


package org.apache.mxnetexamples.rnn

import org.apache.mxnet.DType.DType
import org.apache.mxnet._
import org.slf4j.LoggerFactory

import scala.collection.immutable.ListMap
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
import scala.util.Random
import scala.collection.mutable

object BucketIo {

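  // Text2Id turns one sentence into word ids given a vocabulary;
  // ReadContent loads the raw corpus text from a file path.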
  type Text2Id = (String, Map[String, Int]) => Array[Int]
  type ReadContent = String => String

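  /**
   * Read the file at `path` and replace sentence boundaries
   * (". " and newline) with the " <eos> " separator token.
   */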
  def defaultReadContent(path: String): String = {
    Source.fromFile(path, "UTF-8").mkString.replaceAll("\\. |\n", " <eos> ")
  }

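  /**
   * Build a word-to-id vocabulary from the corpus at `path`;
   * id 0 is reserved for zero-padding.
   */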
  def defaultBuildVocab(path: String): Map[String, Int] = {
    val content = defaultReadContent(path).split(" ")
    var idx = 1 // id 0 is reserved for zero-padding
    val vocab = mutable.Map.empty[String, Int]
    vocab.put(" ", 0) // put a dummy element here so that vocab.size is correct
    content.foreach(word =>
      if (word.length > 0 && !vocab.contains(word)) {
        vocab.put(word, idx)
        idx += 1
      }
    )
    vocab.toMap
  }

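  /**
   * Convert a sentence into an array of word ids; a word missing
   * from `theVocab` throws a NoSuchElementException.
   */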
  def defaultText2Id(sentence: String, theVocab: Map[String, Int]): Array[Int] = {
    sentence.split(" ").filter(_.nonEmpty).map(theVocab)
  }

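  /**
   * Derive bucket sizes from the sentence-length histogram: lengths are merged
   * until at least `batchSize` sentences accumulate, and any remainder falls
   * into a final bucket of the maximum sentence length.
   */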
  def defaultGenBuckets(sentences: Array[String], batchSize: Int,
                        theVocab: Map[String, Int]): IndexedSeq[Int] = {
    val lenDict = scala.collection.mutable.Map[Int, Int]()
    var maxLen = -1
    for (sentence <- sentences) {
      val wordsLen = defaultText2Id(sentence, theVocab).length
      if (wordsLen > 0) {
        if (wordsLen > maxLen) {
          maxLen = wordsLen
        }
        if (lenDict.contains(wordsLen)) {
          lenDict(wordsLen) = lenDict(wordsLen) + 1
        } else {
          lenDict += wordsLen -> 1
        }
      }
    }

    var tl = 0
    val buckets = ArrayBuffer[Int]()
    // walk the lengths in increasing order (a mutable Map has no stable
    // iteration order) so every merged sentence fits the bucket that closes it
    lenDict.toSeq.sortBy(_._1).foreach {
      case (l, n) =>
        if (n + tl >= batchSize) {
          buckets.append(l)
          tl = 0
        } else tl += n
    }
    if (tl > 0) buckets.append(maxLen)
    buckets
  }

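  /**
   * A DataIter that assigns each sentence to the smallest bucket it fits in,
   * right-pads it to the bucket length, and serves shuffled batches per bucket.
   *
   * A minimal usage sketch; the corpus path, layer count, and hidden size below
   * are illustrative assumptions, not values fixed by this file:
   * {{{
   * val batchSize = 32
   * val vocab = BucketIo.defaultBuildVocab("./data/ptb.train.txt")
   * val initStates = (0 until 2).flatMap(l => IndexedSeq(
   *   (s"l${l}_init_c", (batchSize, 200)), (s"l${l}_init_h", (batchSize, 200))))
   * val trainIter = new BucketIo.BucketSentenceIter("./data/ptb.train.txt",
   *   vocab, IndexedSeq(10, 20, 30, 40, 50, 60), batchSize, initStates)
   * }}}
   */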
  class BucketSentenceIter(
      path: String,
      vocab: Map[String, Int],
      var buckets: IndexedSeq[Int],
      _batchSize: Int,
      private val initStates: IndexedSeq[(String, (Int, Int))],
      seperateChar: String = " <eos> ",
      text2Id: Text2Id = defaultText2Id,
      readContent: ReadContent = defaultReadContent) extends DataIter {
    private val logger = LoggerFactory.getLogger(classOf[BucketSentenceIter])

    private val content = readContent(path)
    private val sentences = content.split(seperateChar)

    if (buckets.length == 0) {
      buckets = defaultGenBuckets(sentences, batchSize, vocab)
    }
    buckets = buckets.sorted
    // the largest bucket serves as the default bucket key, so memory
    // allocated for it can be shared with every smaller bucket
    private val _defaultBucketKey = buckets.max
    override def defaultBucketKey: AnyRef = _defaultBucketKey.asInstanceOf[AnyRef]
    // sentences longer than the largest bucket are simply dropped here
    private val data = buckets.indices.map(x => Array[Array[Float]]()).toArray
    for (sentence <- sentences) {
      val ids = text2Id(sentence, vocab)
      if (ids.length > 0) {
        // place the sentence into the smallest bucket that fits it,
        // right-padded with zeros up to the bucket length
        buckets.indices.find(idx => buckets(idx) >= ids.length).foreach { idx =>
          data(idx) = data(idx) :+
            (ids.map(_.toFloat) ++ Array.fill[Float](buckets(idx) - ids.length)(0f))
        }
      }
    }

    // record the size of each bucket so that we can sample
    // uniformly from the buckets
    private val bucketSizes = data.map(_.length)
    logger.info("Summary of dataset ==================")
    buckets.zip(bucketSizes).foreach {
      case (bkt, size) => logger.info(s"bucket of len $bkt : $size samples")
    }

    // make a random data iteration plan:
    // truncate each bucket to a multiple of the batch size
    private var bucketNBatches = Array[Int]()
    for (i <- data.indices) {
      bucketNBatches = bucketNBatches :+ (data(i).length / _batchSize)
      data(i) = data(i).take(bucketNBatches(i) * _batchSize)
    }

    // one entry per batch, naming the bucket the batch is drawn from
    private val bucketPlan = {
      val plan = bucketNBatches.zipWithIndex.flatMap(x => Array.fill[Int](x._1)(x._2))
      Random.shuffle(plan.toList).toArray
    }

    private val bucketIdxAll = data.map(_.length).map(l =>
      Random.shuffle((0 until l).toList).toArray)
    private val bucketCurrIdx = data.map(x => 0)

    private val dataBuffer = ArrayBuffer[NDArray]()
    private val labelBuffer = ArrayBuffer[NDArray]()
    for (iBucket <- data.indices) {
      dataBuffer.append(NDArray.zeros(_batchSize, buckets(iBucket)))
      labelBuffer.append(NDArray.zeros(_batchSize, buckets(iBucket)))
    }

    private val _provideData = {
      val tmp = ListMap("data" -> Shape(_batchSize, _defaultBucketKey))
      tmp ++ initStates.map(x => x._1 -> Shape(x._2._1, x._2._2))
    }

    private val _provideLabel = ListMap("softmax_label" -> Shape(_batchSize, _defaultBucketKey))

    private val _provideDataDesc = {
      // TODO: need to allow user to specify DType and Layout
      val tmp = IndexedSeq(new DataDesc("data",
        Shape(_batchSize, _defaultBucketKey), DType.Float32, Layout.UNDEFINED))
      tmp ++ initStates.map(x => new DataDesc(x._1, Shape(x._2._1, x._2._2),
        DType.Float32, Layout.UNDEFINED))
    }

    private val _provideLabelDesc = IndexedSeq(
      // TODO: need to allow user to specify DType and Layout
      new DataDesc("softmax_label",
        Shape(_batchSize, _defaultBucketKey), DType.Float32, Layout.UNDEFINED))

    private var iBucket = 0

    override def next(): DataBatch = {
      if (!hasNext) throw new NoSuchElementException
      val bucketIdx = bucketPlan(iBucket)
      val dataBuf = dataBuffer(bucketIdx)
      val iIdx = bucketCurrIdx(bucketIdx)
      val idx = bucketIdxAll(bucketIdx).slice(iIdx, iIdx + _batchSize)
      bucketCurrIdx(bucketIdx) = bucketCurrIdx(bucketIdx) + _batchSize

      val datas = idx.map(i => data(bucketIdx)(i))
      for (sentence <- datas) {
        require(sentence.length == buckets(bucketIdx),
          s"sentence length ${sentence.length} does not match bucket size ${buckets(bucketIdx)}")
      }
      dataBuf.set(datas.flatten)

      // the label is the input shifted left by one token, zero-padded at the end
      val labelBuf = labelBuffer(bucketIdx)
      val labels = idx.map(i => data(bucketIdx)(i).drop(1) :+ 0f)
      labelBuf.set(labels.flatten)

      iBucket += 1
      val batchProvideData = IndexedSeq(DataDesc("data", dataBuf.shape, dataBuf.dtype)) ++
        initStates.map {
          case (name, shape) => DataDesc(name, Shape(shape._1, shape._2), DType.Float32)
        }
      val batchProvideLabel = IndexedSeq(DataDesc("softmax_label", labelBuf.shape, labelBuf.dtype))
      val initStateArrays = initStates.map(x => NDArray.zeros(x._2._1, x._2._2))
      new DataBatch(IndexedSeq(dataBuf.copy()) ++ initStateArrays,
        IndexedSeq(labelBuf.copy()),
        getIndex(),
        getPad(),
        this.buckets(bucketIdx).asInstanceOf[AnyRef],
        batchProvideData, batchProvideLabel)
    }

    /**
     * Reset the iterator to the beginning of the data.
     */
    override def reset(): Unit = {
      iBucket = 0
      bucketCurrIdx.indices.foreach(i => bucketCurrIdx(i) = 0)
    }

    override def batchSize: Int = _batchSize

    /**
     * Get the data of the current batch.
     * @return the data of the current batch
     */
    override def getData(): IndexedSeq[NDArray] = IndexedSeq(dataBuffer(bucketPlan(iBucket)))

    /**
     * Get the label of the current batch.
     * @return the label of the current batch
     */
    override def getLabel(): IndexedSeq[NDArray] = IndexedSeq(labelBuffer(bucketPlan(iBucket)))

    /**
     * Get the index of the current batch.
     * @return the indices of the current batch (always empty for this iterator)
     */
    override def getIndex(): IndexedSeq[Long] = IndexedSeq[Long]()

    /**
     * Get the number of padding examples in the current batch.
     * @return the number of padding examples (always 0 for this iterator)
     */
    override def getPad(): Int = 0

    // The name and shape of the label provided by this iterator
    @deprecated("Use provideLabelDesc instead", "1.3.0")
    override def provideLabel: ListMap[String, Shape] = this._provideLabel

    // The name and shape of the data provided by this iterator
    @deprecated("Use provideDataDesc instead", "1.3.0")
    override def provideData: ListMap[String, Shape] = this._provideData

    // Provide the DataDesc of the data
    override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc

    // Provide the DataDesc of the label
    override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc

    override def hasNext: Boolean = {
      iBucket < bucketPlan.length
    }
  }
}