1{
2 "cells": [
3  {
4   "cell_type": "markdown",
5   "metadata": {},
6   "source": [
7    "# Using a bi-lstm to sort a sequence of integers"
8   ]
9  },
10  {
11   "cell_type": "code",
12   "execution_count": 1,
13   "metadata": {},
14   "outputs": [],
15   "source": [
16    "import random\n",
17    "import string\n",
18    "\n",
19    "import mxnet as mx\n",
20    "from mxnet import gluon, nd\n",
21    "import numpy as np"
22   ]
23  },
24  {
25   "cell_type": "markdown",
26   "metadata": {},
27   "source": [
28    "## Data Preparation"
29   ]
30  },
31  {
32   "cell_type": "code",
33   "execution_count": 2,
34   "metadata": {},
35   "outputs": [],
36   "source": [
37    "max_num = 999\n",
38    "dataset_size = 60000\n",
39    "seq_len = 5\n",
40    "split = 0.8\n",
41    "batch_size = 512\n",
42    "ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()"
43   ]
44  },
45  {
46   "cell_type": "markdown",
47   "metadata": {},
48   "source": [
49    "We are getting a dataset of **dataset_size** sequences of integers of length **seq_len** between **0** and **max_num**. We use **split*100%** of them for training and the rest for testing.\n",
50    "\n",
51    "\n",
52    "For example:\n",
53    "\n",
54    "50 10 200 999 30\n",
55    "\n",
56    "Should return\n",
57    "\n",
58    "10 30 50 200 999"
59   ]
60  },
61  {
62   "cell_type": "code",
63   "execution_count": 3,
64   "metadata": {},
65   "outputs": [],
66   "source": [
67    "X = mx.random.uniform(low=0, high=max_num, shape=(dataset_size, seq_len)).astype('int32').asnumpy()\n",
68    "Y = X.copy()\n",
69    "Y.sort() #Let's sort X to get the target"
70   ]
71  },
72  {
73   "cell_type": "code",
74   "execution_count": 4,
75   "metadata": {},
76   "outputs": [
77    {
78     "name": "stdout",
79     "output_type": "stream",
80     "text": [
81      "Input [548, 592, 714, 843, 602]\n",
82      "Target [548, 592, 602, 714, 843]\n"
83     ]
84    }
85   ],
86   "source": [
87    "print(\"Input {}\\nTarget {}\".format(X[0].tolist(), Y[0].tolist()))"
88   ]
89  },
90  {
91   "cell_type": "markdown",
92   "metadata": {},
93   "source": [
94    "For the purpose of training, we encode the input as characters rather than numbers"
95   ]
96  },
97  {
98   "cell_type": "code",
99   "execution_count": 5,
100   "metadata": {},
101   "outputs": [
102    {
103     "name": "stdout",
104     "output_type": "stream",
105     "text": [
106      "0123456789 \n",
107      "{'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, ' ': 10}\n"
108     ]
109    }
110   ],
111   "source": [
112    "vocab = string.digits + \" \"\n",
113    "print(vocab)\n",
114    "vocab_idx = { c:i for i,c in enumerate(vocab)}\n",
115    "print(vocab_idx)"
116   ]
117  },
118  {
119   "cell_type": "markdown",
120   "metadata": {},
121   "source": [
122    "We write a transform that will convert our numbers into text of maximum length **max_len**, and one-hot encode the characters.\n",
123    "For example:\n",
124    "\n",
125    "\"30 10\" corresponding indices are [3, 0, 10, 1, 0]\n",
126    "\n",
127    "We then one hot encode that and get a matrix representation of our input. We don't need to encode our target as the loss we are going to use support sparse labels"
128   ]
129  },
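  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick illustration, using the `[3, 0, 10, 1, 0]` indices from the example above, `mx.nd.one_hot` turns each index into one row of a matrix:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustration: one-hot encode the indices of \"30 10\" into a (5, len(vocab)) matrix\n",
    "print(mx.nd.one_hot(mx.nd.array([3, 0, 10, 1, 0]), len(vocab)))"
   ]
  },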
130  {
131   "cell_type": "code",
132   "execution_count": 6,
133   "metadata": {},
134   "outputs": [
135    {
136     "name": "stdout",
137     "output_type": "stream",
138     "text": [
139      "Maximum length of the string: 19\n"
140     ]
141    }
142   ],
143   "source": [
144    "max_len = len(str(max_num))*seq_len+(seq_len-1)\n",
145    "print(\"Maximum length of the string: %s\" % max_len)"
146   ]
147  },
148  {
149   "cell_type": "code",
150   "execution_count": 7,
151   "metadata": {},
152   "outputs": [],
153   "source": [
154    "def transform(x, y):\n",
155    "    x_string = ' '.join(map(str, x.tolist()))\n",
156    "    x_string_padded = x_string + ' '*(max_len-len(x_string))\n",
157    "    x = [vocab_idx[c] for c in x_string_padded]\n",
158    "    y_string = ' '.join(map(str, y.tolist()))\n",
159    "    y_string_padded = y_string + ' '*(max_len-len(y_string))\n",
160    "    y = [vocab_idx[c] for c in y_string_padded]\n",
161    "    return mx.nd.one_hot(mx.nd.array(x), len(vocab)), mx.nd.array(y)"
162   ]
163  },
164  {
165   "cell_type": "code",
166   "execution_count": 8,
167   "metadata": {},
168   "outputs": [],
169   "source": [
170    "split_idx = int(split*len(X))\n",
171    "train_dataset = gluon.data.ArrayDataset(X[:split_idx], Y[:split_idx]).transform(transform)\n",
172    "test_dataset = gluon.data.ArrayDataset(X[split_idx:], Y[split_idx:]).transform(transform)"
173   ]
174  },
175  {
176   "cell_type": "code",
177   "execution_count": 9,
178   "metadata": {},
179   "outputs": [
180    {
181     "name": "stdout",
182     "output_type": "stream",
183     "text": [
184      "Input [548 592 714 843 602]\n",
185      "Transformed data Input \n",
186      "[[0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
187      " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n",
188      " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n",
189      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
190      " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
191      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n",
192      " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
193      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
194      " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n",
195      " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
196      " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n",
197      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
198      " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n",
199      " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n",
200      " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n",
201      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
202      " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n",
203      " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
204      " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]]\n",
205      "<NDArray 19x11 @cpu(0)>\n",
206      "Target [548 592 602 714 843]\n",
207      "Transformed data Target \n",
208      "[ 5.  4.  8. 10.  5.  9.  2. 10.  6.  0.  2. 10.  7.  1.  4. 10.  8.  4.\n",
209      "  3.]\n",
210      "<NDArray 19 @cpu(0)>\n"
211     ]
212    }
213   ],
214   "source": [
215    "print(\"Input {}\".format(X[0]))\n",
216    "print(\"Transformed data Input {}\".format(train_dataset[0][0]))\n",
217    "print(\"Target {}\".format(Y[0]))\n",
218    "print(\"Transformed data Target {}\".format(train_dataset[0][1]))"
219   ]
220  },
221  {
222   "cell_type": "code",
223   "execution_count": 10,
224   "metadata": {},
225   "outputs": [],
226   "source": [
227    "train_data = gluon.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=20, last_batch='rollover')\n",
228    "test_data = gluon.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=5, last_batch='rollover')"
229   ]
230  },
231  {
232   "cell_type": "markdown",
233   "metadata": {},
234   "source": [
235    "## Creating the network"
236   ]
237  },
238  {
239   "cell_type": "code",
240   "execution_count": 11,
241   "metadata": {},
242   "outputs": [],
243   "source": [
244    "net = gluon.nn.HybridSequential()\n",
245    "with net.name_scope():\n",
246    "    net.add(\n",
247    "        gluon.rnn.LSTM(hidden_size=128, num_layers=2, layout='NTC', bidirectional=True),\n",
248    "        gluon.nn.Dense(len(vocab), flatten=False)\n",
249    "    )"
250   ]
251  },
252  {
253   "cell_type": "code",
254   "execution_count": 12,
255   "metadata": {},
256   "outputs": [],
257   "source": [
258    "net.initialize(mx.init.Xavier(), ctx=ctx)"
259   ]
260  },
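  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check (a hypothetical smoke test, not part of the original run), the network should map a batch of one-hot encoded sequences of shape (batch, max_len, len(vocab)) to per-character logits of the same shape:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical smoke test: run a dummy batch of 2 sequences through the network\n",
    "dummy = mx.nd.random.uniform(shape=(2, max_len, len(vocab)), ctx=ctx)\n",
    "print(net(dummy).shape)  # expect (2, max_len, len(vocab)) = (2, 19, 11)"
   ]
  },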
261  {
262   "cell_type": "code",
263   "execution_count": 13,
264   "metadata": {},
265   "outputs": [],
266   "source": [
267    "loss = gluon.loss.SoftmaxCELoss()"
268   ]
269  },
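  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`SoftmaxCELoss` applies a softmax along the last axis and accepts sparse (integer) labels by default, which is why we did not one-hot encode the targets. A minimal illustration (hypothetical values, one sequence of 3 characters):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical illustration: per-character logits against sparse integer labels\n",
    "logits = mx.nd.random.uniform(shape=(1, 3, len(vocab)))  # (batch, characters, vocab)\n",
    "labels = mx.nd.array([[1, 0, 10]])  # (batch, characters), one index per character\n",
    "print(loss(logits, labels))  # one loss value per sequence"
   ]
  },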
270  {
271   "cell_type": "markdown",
272   "metadata": {},
273   "source": [
274    "We use a learning rate schedule to improve the convergence of the model"
275   ]
276  },
277  {
278   "cell_type": "code",
279   "execution_count": 14,
280   "metadata": {},
281   "outputs": [],
282   "source": [
283    "schedule = mx.lr_scheduler.FactorScheduler(step=len(train_data)*10, factor=0.75)\n",
284    "schedule.base_lr = 0.01"
285   ]
286  },
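  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the schedule does, we can query a throwaway copy at a few update counts (illustrative only, and done on a copy so we don't advance the real schedule's internal state):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: a throwaway copy of the schedule, queried after 0/10/20/30 epochs of updates\n",
    "demo = mx.lr_scheduler.FactorScheduler(step=len(train_data)*10, factor=0.75)\n",
    "demo.base_lr = 0.01\n",
    "for e in [0, 10, 20, 30]:\n",
    "    print(\"epoch {:>2} -> lr {}\".format(e, demo(e*len(train_data)+1)))"
   ]
  },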
287  {
288   "cell_type": "code",
289   "execution_count": 15,
290   "metadata": {},
291   "outputs": [],
292   "source": [
293    "trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate':0.01, 'lr_scheduler':schedule})"
294   ]
295  },
296  {
297   "cell_type": "markdown",
298   "metadata": {},
299   "source": [
300    "## Training loop"
301   ]
302  },
303  {
304   "cell_type": "code",
305   "execution_count": 16,
306   "metadata": {},
307   "outputs": [
308    {
309     "name": "stdout",
310     "output_type": "stream",
311     "text": [
312      "Epoch [0] Loss: 1.6627886372227823, LR 0.01\n",
313      "Epoch [1] Loss: 1.210370733382854, LR 0.01\n",
314      "Epoch [2] Loss: 0.9692377131035987, LR 0.01\n",
315      "Epoch [3] Loss: 0.7976046623067653, LR 0.01\n",
316      "Epoch [4] Loss: 0.5714595343476983, LR 0.01\n",
317      "Epoch [5] Loss: 0.4458411196444897, LR 0.01\n",
318      "Epoch [6] Loss: 0.36039798817736035, LR 0.01\n",
319      "Epoch [7] Loss: 0.32665719377233626, LR 0.01\n",
320      "Epoch [8] Loss: 0.262064205702915, LR 0.01\n",
321      "Epoch [9] Loss: 0.22285924059279422, LR 0.0075\n",
322      "Epoch [10] Loss: 0.19018426854559717, LR 0.0075\n",
323      "Epoch [11] Loss: 0.1718730723604243, LR 0.0075\n",
324      "Epoch [12] Loss: 0.15736752171670237, LR 0.0075\n",
325      "Epoch [13] Loss: 0.14579375246737866, LR 0.0075\n",
326      "Epoch [14] Loss: 0.13546599733068587, LR 0.0075\n",
327      "Epoch [15] Loss: 0.12490207590955368, LR 0.0075\n",
328      "Epoch [16] Loss: 0.11803316300915133, LR 0.0075\n",
329      "Epoch [17] Loss: 0.10653189395336395, LR 0.0075\n",
330      "Epoch [18] Loss: 0.10514750379197141, LR 0.0075\n",
331      "Epoch [19] Loss: 0.09590611559279422, LR 0.005625\n",
332      "Epoch [20] Loss: 0.08146028108494256, LR 0.005625\n",
333      "Epoch [21] Loss: 0.07707348782965477, LR 0.005625\n",
334      "Epoch [22] Loss: 0.07206193436967566, LR 0.005625\n",
335      "Epoch [23] Loss: 0.07001185417175293, LR 0.005625\n",
336      "Epoch [24] Loss: 0.06797058351578252, LR 0.005625\n",
337      "Epoch [25] Loss: 0.0649358110224947, LR 0.005625\n",
338      "Epoch [26] Loss: 0.06219124286732775, LR 0.005625\n",
339      "Epoch [27] Loss: 0.06075144828634059, LR 0.005625\n",
340      "Epoch [28] Loss: 0.05711334495134251, LR 0.005625\n",
341      "Epoch [29] Loss: 0.054747099572039666, LR 0.00421875\n",
342      "Epoch [30] Loss: 0.0441775271233092, LR 0.00421875\n",
343      "Epoch [31] Loss: 0.041551097910454936, LR 0.00421875\n",
344      "Epoch [32] Loss: 0.04095017269093503, LR 0.00421875\n",
345      "Epoch [33] Loss: 0.04045371045457556, LR 0.00421875\n",
346      "Epoch [34] Loss: 0.038867686657195394, LR 0.00421875\n",
347      "Epoch [35] Loss: 0.038131744303601854, LR 0.00421875\n",
348      "Epoch [36] Loss: 0.039834817250569664, LR 0.00421875\n",
349      "Epoch [37] Loss: 0.03669035941996473, LR 0.00421875\n",
350      "Epoch [38] Loss: 0.03373505967728635, LR 0.00421875\n",
351      "Epoch [39] Loss: 0.03164981273894615, LR 0.0031640625\n",
352      "Epoch [40] Loss: 0.025532766055035336, LR 0.0031640625\n",
353      "Epoch [41] Loss: 0.022659448867148543, LR 0.0031640625\n",
354      "Epoch [42] Loss: 0.02307056112492338, LR 0.0031640625\n",
355      "Epoch [43] Loss: 0.02236944056571798, LR 0.0031640625\n",
356      "Epoch [44] Loss: 0.022204211963120328, LR 0.0031640625\n",
357      "Epoch [45] Loss: 0.02262336903430046, LR 0.0031640625\n",
358      "Epoch [46] Loss: 0.02253308448385685, LR 0.0031640625\n",
359      "Epoch [47] Loss: 0.025286573044797207, LR 0.0031640625\n",
360      "Epoch [48] Loss: 0.02439300988310127, LR 0.0031640625\n",
361      "Epoch [49] Loss: 0.017976388018181983, LR 0.002373046875\n",
362      "Epoch [50] Loss: 0.014343131095805067, LR 0.002373046875\n",
363      "Epoch [51] Loss: 0.013039355582379281, LR 0.002373046875\n",
364      "Epoch [52] Loss: 0.011884741885687715, LR 0.002373046875\n",
365      "Epoch [53] Loss: 0.011438189668858305, LR 0.002373046875\n",
366      "Epoch [54] Loss: 0.011447292693117832, LR 0.002373046875\n",
367      "Epoch [55] Loss: 0.014212571560068334, LR 0.002373046875\n",
368      "Epoch [56] Loss: 0.019900493724371797, LR 0.002373046875\n",
369      "Epoch [57] Loss: 0.02102568301748722, LR 0.002373046875\n",
370      "Epoch [58] Loss: 0.01346214400961044, LR 0.002373046875\n",
371      "Epoch [59] Loss: 0.010107964911359422, LR 0.0017797851562500002\n",
372      "Epoch [60] Loss: 0.008353193600972494, LR 0.0017797851562500002\n",
373      "Epoch [61] Loss: 0.007678258292218472, LR 0.0017797851562500002\n",
374      "Epoch [62] Loss: 0.007262124660167288, LR 0.0017797851562500002\n",
375      "Epoch [63] Loss: 0.00705223578087827, LR 0.0017797851562500002\n",
376      "Epoch [64] Loss: 0.006788556293774677, LR 0.0017797851562500002\n",
377      "Epoch [65] Loss: 0.006473606571238091, LR 0.0017797851562500002\n",
378      "Epoch [66] Loss: 0.006206096486842378, LR 0.0017797851562500002\n",
379      "Epoch [67] Loss: 0.00584477313021396, LR 0.0017797851562500002\n",
380      "Epoch [68] Loss: 0.005648705267137097, LR 0.0017797851562500002\n",
381      "Epoch [69] Loss: 0.006481769871204458, LR 0.0013348388671875003\n",
382      "Epoch [70] Loss: 0.008430448618341, LR 0.0013348388671875003\n",
383      "Epoch [71] Loss: 0.006877245421105242, LR 0.0013348388671875003\n",
384      "Epoch [72] Loss: 0.005671108281740578, LR 0.0013348388671875003\n",
385      "Epoch [73] Loss: 0.004832422162624116, LR 0.0013348388671875003\n",
386      "Epoch [74] Loss: 0.004441103402604448, LR 0.0013348388671875003\n",
387      "Epoch [75] Loss: 0.004216198591475791, LR 0.0013348388671875003\n",
388      "Epoch [76] Loss: 0.004041922989711967, LR 0.0013348388671875003\n",
389      "Epoch [77] Loss: 0.003937713643337818, LR 0.0013348388671875003\n",
390      "Epoch [78] Loss: 0.010251983049068046, LR 0.0013348388671875003\n",
391      "Epoch [79] Loss: 0.01829354052848004, LR 0.0010011291503906252\n",
392      "Epoch [80] Loss: 0.006723233448561802, LR 0.0010011291503906252\n",
393      "Epoch [81] Loss: 0.004397524798170049, LR 0.0010011291503906252\n",
394      "Epoch [82] Loss: 0.0038475305476087206, LR 0.0010011291503906252\n",
395      "Epoch [83] Loss: 0.003591177945441388, LR 0.0010011291503906252\n",
396      "Epoch [84] Loss: 0.003425112014175743, LR 0.0010011291503906252\n",
397      "Epoch [85] Loss: 0.0032633850549129728, LR 0.0010011291503906252\n",
398      "Epoch [86] Loss: 0.0031762316505959693, LR 0.0010011291503906252\n",
399      "Epoch [87] Loss: 0.0030452777096565734, LR 0.0010011291503906252\n",
400      "Epoch [88] Loss: 0.002950224184220837, LR 0.0010011291503906252\n",
401      "Epoch [89] Loss: 0.002821172171450676, LR 0.0007508468627929689\n",
402      "Epoch [90] Loss: 0.002725780961361337, LR 0.0007508468627929689\n",
403      "Epoch [91] Loss: 0.002660556359493986, LR 0.0007508468627929689\n",
404      "Epoch [92] Loss: 0.0026011724946319414, LR 0.0007508468627929689\n",
405      "Epoch [93] Loss: 0.0025355776256703317, LR 0.0007508468627929689\n",
406      "Epoch [94] Loss: 0.0024825221997626283, LR 0.0007508468627929689\n",
407      "Epoch [95] Loss: 0.0024245587435174497, LR 0.0007508468627929689\n",
408      "Epoch [96] Loss: 0.002365282145879602, LR 0.0007508468627929689\n",
409      "Epoch [97] Loss: 0.0023112583984719946, LR 0.0007508468627929689\n",
410      "Epoch [98] Loss: 0.002257173682780976, LR 0.0007508468627929689\n",
411      "Epoch [99] Loss: 0.002162747085094452, LR 0.0005631351470947267\n"
412     ]
413    }
414   ],
415   "source": [
416    "epochs = 100\n",
417    "for e in range(epochs):\n",
418    "    epoch_loss = 0.\n",
419    "    for i, (data, label) in enumerate(train_data):\n",
420    "        data = data.as_in_context(ctx)\n",
421    "        label = label.as_in_context(ctx)\n",
422    "\n",
423    "        with mx.autograd.record():\n",
424    "            output = net(data)\n",
425    "            l = loss(output, label)\n",
426    "\n",
427    "        l.backward()\n",
428    "        trainer.step(data.shape[0])\n",
429    "    \n",
430    "        epoch_loss += l.mean()\n",
431    "        \n",
432    "    print(\"Epoch [{}] Loss: {}, LR {}\".format(e, epoch_loss.asscalar()/(i+1), trainer.learning_rate))"
433   ]
434  },
435  {
436   "cell_type": "markdown",
437   "metadata": {},
438   "source": [
439    "## Testing"
440   ]
441  },
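  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before looking at individual examples, we can measure the average loss on the held-out set (a minimal evaluation sketch; this cell was not part of the original run):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal evaluation sketch: average per-sequence loss over the test DataLoader\n",
    "total, batches = 0., 0\n",
    "for data, label in test_data:\n",
    "    data = data.as_in_context(ctx)\n",
    "    label = label.as_in_context(ctx)\n",
    "    total += loss(net(data), label).mean().asscalar()\n",
    "    batches += 1\n",
    "print(\"Average test loss: {:.4f}\".format(total/batches))"
   ]
  },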
442  {
443   "cell_type": "markdown",
444   "metadata": {},
445   "source": [
446    "We get a random element from the testing set"
447   ]
448  },
449  {
450   "cell_type": "code",
451   "execution_count": 17,
452   "metadata": {},
453   "outputs": [],
454   "source": [
455    "n = random.randint(0, len(test_data)-1)\n",
456    "\n",
457    "x_orig = X[split_idx+n]\n",
458    "y_orig = Y[split_idx+n]"
459   ]
460  },
461  {
462   "cell_type": "code",
463   "execution_count": 41,
464   "metadata": {},
465   "outputs": [],
466   "source": [
467    "def get_pred(x):\n",
468    "    x, _ = transform(x, x)\n",
469    "    output = net(x.as_in_context(ctx).expand_dims(axis=0))\n",
470    "\n",
471    "    # Convert output back to string\n",
472    "    pred = ''.join([vocab[int(o)] for o in output[0].argmax(axis=1).asnumpy().tolist()])\n",
473    "    return pred"
474   ]
475  },
476  {
477   "cell_type": "markdown",
478   "metadata": {},
479   "source": [
480    "Printing the result"
481   ]
482  },
483  {
484   "cell_type": "code",
485   "execution_count": 43,
486   "metadata": {},
487   "outputs": [
488    {
489     "name": "stdout",
490     "output_type": "stream",
491     "text": [
492      "X         611 671 275 871 944\n",
493      "Predicted 275 611 671 871 944\n",
494      "Label     275 611 671 871 944\n"
495     ]
496    }
497   ],
498   "source": [
499    "x_ = ' '.join(map(str,x_orig))\n",
500    "label = ' '.join(map(str,y_orig))\n",
501    "print(\"X         {}\\nPredicted {}\\nLabel     {}\".format(x_, get_pred(x_orig), label))"
502   ]
503  },
504  {
505   "cell_type": "markdown",
506   "metadata": {},
507   "source": [
508    "We can also pick our own example, and the network manages to sort it without problem:"
509   ]
510  },
511  {
512   "cell_type": "code",
513   "execution_count": 66,
514   "metadata": {},
515   "outputs": [
516    {
517     "name": "stdout",
518     "output_type": "stream",
519     "text": [
520      "10 30 130 500 999  \n"
521     ]
522    }
523   ],
524   "source": [
525    "print(get_pred(np.array([500, 30, 999, 10, 130])))"
526   ]
527  },
528  {
529   "cell_type": "markdown",
530   "metadata": {},
531   "source": [
532    "The model has even learned to generalize to examples not on the training set"
533   ]
534  },
535  {
536   "cell_type": "code",
537   "execution_count": 64,
538   "metadata": {},
539   "outputs": [
540    {
541     "name": "stdout",
542     "output_type": "stream",
543     "text": [
544      "Only four numbers: 105 202 302 501    \n"
545     ]
546    }
547   ],
548   "source": [
549    "print(\"Only four numbers:\", get_pred(np.array([105, 302, 501, 202])))"
550   ]
551  },
552  {
553   "cell_type": "markdown",
554   "metadata": {},
555   "source": [
556    "However we can see it has trouble with other edge cases:"
557   ]
558  },
559  {
560   "cell_type": "code",
561   "execution_count": 63,
562   "metadata": {},
563   "outputs": [
564    {
565     "name": "stdout",
566     "output_type": "stream",
567     "text": [
568      "Small digits: 8  0 42 28         \n",
569      "Small digits, 6 numbers: 10 0 20 82 71 115  \n"
570     ]
571    }
572   ],
573   "source": [
574    "print(\"Small digits:\", get_pred(np.array([10, 3, 5, 2, 8])))\n",
575    "print(\"Small digits, 6 numbers:\", get_pred(np.array([10, 33, 52, 21, 82, 10])))"
576   ]
577  },
578  {
579   "cell_type": "markdown",
580   "metadata": {},
581   "source": [
582    "This could be improved by adjusting the training dataset accordingly"
583   ]
584  }
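  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One hypothetical way to do that (a sketch, not part of the original notebook) is to regenerate the training data with varying number magnitudes so that 1- and 2-digit numbers are well represented:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical augmentation sketch: mix magnitudes so short numbers appear often\n",
    "def sample_sequence():\n",
    "    high = random.choice([9, 99, max_num])  # vary the number of digits per sequence\n",
    "    return [random.randint(0, high) for _ in range(seq_len)]\n",
    "\n",
    "X_aug = np.array([sample_sequence() for _ in range(dataset_size)])\n",
    "Y_aug = np.sort(X_aug, axis=1)"
   ]
  }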
585 ],
586 "metadata": {
587  "kernelspec": {
588   "display_name": "Python 3",
589   "language": "python",
590   "name": "python3"
591  },
592  "language_info": {
593   "codemirror_mode": {
594    "name": "ipython",
595    "version": 3
596   },
597   "file_extension": ".py",
598   "mimetype": "text/x-python",
599   "name": "python",
600   "nbconvert_exporter": "python",
601   "pygments_lexer": "ipython3",
602   "version": "3.6.4"
603  }
604 },
605 "nbformat": 4,
606 "nbformat_minor": 2
607}
608