{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using a Bi-LSTM to sort a sequence of integers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import string\n",
    "\n",
    "import mxnet as mx\n",
    "from mxnet import gluon, nd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_num = 999\n",
    "dataset_size = 60000\n",
    "seq_len = 5\n",
    "split = 0.8\n",
    "batch_size = 512\n",
    "ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We generate a dataset of **dataset_size** sequences of **seq_len** integers between **0** and **max_num**. We use a fraction **split** of them for training and the rest for testing.\n",
    "\n",
    "For example:\n",
    "\n",
    "50 10 200 999 30\n",
    "\n",
    "should return\n",
    "\n",
    "10 30 50 200 999"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = mx.random.uniform(low=0, high=max_num, shape=(dataset_size, seq_len)).astype('int32').asnumpy()\n",
    "Y = X.copy()\n",
    "Y.sort()  # Sort each row of the copy to get the target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input [548, 592, 714, 843, 602]\n",
      "Target [548, 592, 602, 714, 843]\n"
     ]
    }
   ],
   "source": [
    "print(\"Input {}\\nTarget {}\".format(X[0].tolist(), Y[0].tolist()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the purpose of training, we encode the input as characters rather than numbers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0123456789 \n",
      "{'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, ' ': 10}\n"
     ]
    }
   ],
   "source": [
    "vocab = string.digits + \" \"\n",
    "print(vocab)\n",
    "vocab_idx = {c: i for i, c in enumerate(vocab)}\n",
    "print(vocab_idx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We write a transform that converts our numbers into text of maximum length **max_len** and one-hot encodes the characters.\n",
    "For example:\n",
    "\n",
    "\"30 10\" corresponds to the indices [3, 0, 10, 1, 0]\n",
    "\n",
    "We then one-hot encode those indices to get a matrix representation of our input. We don't need to encode our target, as the loss we are going to use supports sparse labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum length of the string: 19\n"
     ]
    }
   ],
   "source": [
    "max_len = len(str(max_num))*seq_len+(seq_len-1)\n",
    "print(\"Maximum length of the string: %s\" % max_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def transform(x, y):\n",
    "    x_string = ' '.join(map(str, x.tolist()))\n",
    "    x_string_padded = x_string + ' '*(max_len-len(x_string))\n",
    "    x = [vocab_idx[c] for c in x_string_padded]\n",
    "    y_string = ' '.join(map(str, y.tolist()))\n",
    "    y_string_padded = y_string + ' '*(max_len-len(y_string))\n",
    "    y = [vocab_idx[c] for c in y_string_padded]\n",
    "    return mx.nd.one_hot(mx.nd.array(x), len(vocab)), mx.nd.array(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "split_idx = int(split*len(X))\n",
    "train_dataset = gluon.data.ArrayDataset(X[:split_idx], Y[:split_idx]).transform(transform)\n",
    "test_dataset = gluon.data.ArrayDataset(X[split_idx:], Y[split_idx:]).transform(transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input [548 592 714 843 602]\n",
      "Transformed data Input \n",
      "[[0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n",
      " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n",
      " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n",
      " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n",
      " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]]\n",
      "<NDArray 19x11 @cpu(0)>\n",
      "Target [548 592 602 714 843]\n",
      "Transformed data Target \n",
      "[ 5. 4. 8. 10. 5. 9. 2. 10. 6. 0. 2. 10. 7. 1. 4. 10. 8. 4.\n",
      " 3.]\n",
      "<NDArray 19 @cpu(0)>\n"
     ]
    }
   ],
   "source": [
    "print(\"Input {}\".format(X[0]))\n",
    "print(\"Transformed data Input {}\".format(train_dataset[0][0]))\n",
    "print(\"Target {}\".format(Y[0]))\n",
    "print(\"Transformed data Target {}\".format(train_dataset[0][1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = gluon.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=20, last_batch='rollover')\n",
    "test_data = gluon.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=5, last_batch='rollover')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating the network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = gluon.nn.HybridSequential()\n",
    "with net.name_scope():\n",
    "    net.add(\n",
    "        gluon.rnn.LSTM(hidden_size=128, num_layers=2, layout='NTC', bidirectional=True),\n",
    "        gluon.nn.Dense(len(vocab), flatten=False)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "net.initialize(mx.init.Xavier(), ctx=ctx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss = gluon.loss.SoftmaxCELoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use a learning rate schedule to improve the convergence of the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "schedule = mx.lr_scheduler.FactorScheduler(step=len(train_data)*10, factor=0.75)\n",
    "schedule.base_lr = 0.01"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': 0.01, 'lr_scheduler': schedule})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [0] Loss: 1.6627886372227823, LR 0.01\n",
      "Epoch [1] Loss: 1.210370733382854, LR 0.01\n",
      "Epoch [2] Loss: 0.9692377131035987, LR 0.01\n",
      "Epoch [3] Loss: 0.7976046623067653, LR 0.01\n",
      "Epoch [4] Loss: 0.5714595343476983, LR 0.01\n",
      "Epoch [5] Loss: 0.4458411196444897, LR 0.01\n",
      "Epoch [6] Loss: 0.36039798817736035, LR 0.01\n",
      "Epoch [7] Loss: 0.32665719377233626, LR 0.01\n",
      "Epoch [8] Loss: 0.262064205702915, LR 0.01\n",
      "Epoch [9] Loss: 0.22285924059279422, LR 0.0075\n",
      "Epoch [10] Loss: 0.19018426854559717, LR 0.0075\n",
      "Epoch [11] Loss: 0.1718730723604243, LR 0.0075\n",
      "Epoch [12] Loss: 0.15736752171670237, LR 0.0075\n",
      "Epoch [13] Loss: 0.14579375246737866, LR 0.0075\n",
      "Epoch [14] Loss: 0.13546599733068587, LR 0.0075\n",
      "Epoch [15] Loss: 0.12490207590955368, LR 0.0075\n",
      "Epoch [16] Loss: 0.11803316300915133, LR 0.0075\n",
      "Epoch [17] Loss: 0.10653189395336395, LR 0.0075\n",
      "Epoch [18] Loss: 0.10514750379197141, LR 0.0075\n",
      "Epoch [19] Loss: 0.09590611559279422, LR 0.005625\n",
      "Epoch [20] Loss: 0.08146028108494256, LR 0.005625\n",
      "Epoch [21] Loss: 0.07707348782965477, LR 0.005625\n",
      "Epoch [22] Loss: 0.07206193436967566, LR 0.005625\n",
      "Epoch [23] Loss: 0.07001185417175293, LR 0.005625\n",
      "Epoch [24] Loss: 0.06797058351578252, LR 0.005625\n",
      "Epoch [25] Loss: 0.0649358110224947, LR 0.005625\n",
      "Epoch [26] Loss: 0.06219124286732775, LR 0.005625\n",
      "Epoch [27] Loss: 0.06075144828634059, LR 0.005625\n",
      "Epoch [28] Loss: 0.05711334495134251, LR 0.005625\n",
      "Epoch [29] Loss: 0.054747099572039666, LR 0.00421875\n",
      "Epoch [30] Loss: 0.0441775271233092, LR 0.00421875\n",
      "Epoch [31] Loss: 0.041551097910454936, LR 0.00421875\n",
      "Epoch [32] Loss: 0.04095017269093503, LR 0.00421875\n",
      "Epoch [33] Loss: 0.04045371045457556, LR 0.00421875\n",
      "Epoch [34] Loss: 0.038867686657195394, LR 0.00421875\n",
      "Epoch [35] Loss: 0.038131744303601854, LR 0.00421875\n",
      "Epoch [36] Loss: 0.039834817250569664, LR 0.00421875\n",
      "Epoch [37] Loss: 0.03669035941996473, LR 0.00421875\n",
      "Epoch [38] Loss: 0.03373505967728635, LR 0.00421875\n",
      "Epoch [39] Loss: 0.03164981273894615, LR 0.0031640625\n",
      "Epoch [40] Loss: 0.025532766055035336, LR 0.0031640625\n",
      "Epoch [41] Loss: 0.022659448867148543, LR 0.0031640625\n",
      "Epoch [42] Loss: 0.02307056112492338, LR 0.0031640625\n",
      "Epoch [43] Loss: 0.02236944056571798, LR 0.0031640625\n",
      "Epoch [44] Loss: 0.022204211963120328, LR 0.0031640625\n",
      "Epoch [45] Loss: 0.02262336903430046, LR 0.0031640625\n",
      "Epoch [46] Loss: 0.02253308448385685, LR 0.0031640625\n",
      "Epoch [47] Loss: 0.025286573044797207, LR 0.0031640625\n",
      "Epoch [48] Loss: 0.02439300988310127, LR 0.0031640625\n",
      "Epoch [49] Loss: 0.017976388018181983, LR 0.002373046875\n",
      "Epoch [50] Loss: 0.014343131095805067, LR 0.002373046875\n",
      "Epoch [51] Loss: 0.013039355582379281, LR 0.002373046875\n",
      "Epoch [52] Loss: 0.011884741885687715, LR 0.002373046875\n",
      "Epoch [53] Loss: 0.011438189668858305, LR 0.002373046875\n",
      "Epoch [54] Loss: 0.011447292693117832, LR 0.002373046875\n",
      "Epoch [55] Loss: 0.014212571560068334, LR 0.002373046875\n",
      "Epoch [56] Loss: 0.019900493724371797, LR 0.002373046875\n",
      "Epoch [57] Loss: 0.02102568301748722, LR 0.002373046875\n",
      "Epoch [58] Loss: 0.01346214400961044, LR 0.002373046875\n",
      "Epoch [59] Loss: 0.010107964911359422, LR 0.0017797851562500002\n",
      "Epoch [60] Loss: 0.008353193600972494, LR 0.0017797851562500002\n",
      "Epoch [61] Loss: 0.007678258292218472, LR 0.0017797851562500002\n",
      "Epoch [62] Loss: 0.007262124660167288, LR 0.0017797851562500002\n",
      "Epoch [63] Loss: 0.00705223578087827, LR 0.0017797851562500002\n",
      "Epoch [64] Loss: 0.006788556293774677, LR 0.0017797851562500002\n",
      "Epoch [65] Loss: 0.006473606571238091, LR 0.0017797851562500002\n",
      "Epoch [66] Loss: 0.006206096486842378, LR 0.0017797851562500002\n",
      "Epoch [67] Loss: 0.00584477313021396, LR 0.0017797851562500002\n",
      "Epoch [68] Loss: 0.005648705267137097, LR 0.0017797851562500002\n",
      "Epoch [69] Loss: 0.006481769871204458, LR 0.0013348388671875003\n",
      "Epoch [70] Loss: 0.008430448618341, LR 0.0013348388671875003\n",
      "Epoch [71] Loss: 0.006877245421105242, LR 0.0013348388671875003\n",
      "Epoch [72] Loss: 0.005671108281740578, LR 0.0013348388671875003\n",
      "Epoch [73] Loss: 0.004832422162624116, LR 0.0013348388671875003\n",
      "Epoch [74] Loss: 0.004441103402604448, LR 0.0013348388671875003\n",
      "Epoch [75] Loss: 0.004216198591475791, LR 0.0013348388671875003\n",
      "Epoch [76] Loss: 0.004041922989711967, LR 0.0013348388671875003\n",
      "Epoch [77] Loss: 0.003937713643337818, LR 0.0013348388671875003\n",
      "Epoch [78] Loss: 0.010251983049068046, LR 0.0013348388671875003\n",
      "Epoch [79] Loss: 0.01829354052848004, LR 0.0010011291503906252\n",
      "Epoch [80] Loss: 0.006723233448561802, LR 0.0010011291503906252\n",
      "Epoch [81] Loss: 0.004397524798170049, LR 0.0010011291503906252\n",
      "Epoch [82] Loss: 0.0038475305476087206, LR 0.0010011291503906252\n",
      "Epoch [83] Loss: 0.003591177945441388, LR 0.0010011291503906252\n",
      "Epoch [84] Loss: 0.003425112014175743, LR 0.0010011291503906252\n",
      "Epoch [85] Loss: 0.0032633850549129728, LR 0.0010011291503906252\n",
      "Epoch [86] Loss: 0.0031762316505959693, LR 0.0010011291503906252\n",
      "Epoch [87] Loss: 0.0030452777096565734, LR 0.0010011291503906252\n",
      "Epoch [88] Loss: 0.002950224184220837, LR 0.0010011291503906252\n",
      "Epoch [89] Loss: 0.002821172171450676, LR 0.0007508468627929689\n",
      "Epoch [90] Loss: 0.002725780961361337, LR 0.0007508468627929689\n",
      "Epoch [91] Loss: 0.002660556359493986, LR 0.0007508468627929689\n",
      "Epoch [92] Loss: 0.0026011724946319414, LR 0.0007508468627929689\n",
      "Epoch [93] Loss: 0.0025355776256703317, LR 0.0007508468627929689\n",
      "Epoch [94] Loss: 0.0024825221997626283, LR 0.0007508468627929689\n",
      "Epoch [95] Loss: 0.0024245587435174497, LR 0.0007508468627929689\n",
      "Epoch [96] Loss: 0.002365282145879602, LR 0.0007508468627929689\n",
      "Epoch [97] Loss: 0.0023112583984719946, LR 0.0007508468627929689\n",
      "Epoch [98] Loss: 0.002257173682780976, LR 0.0007508468627929689\n",
      "Epoch [99] Loss: 0.002162747085094452, LR 0.0005631351470947267\n"
     ]
    }
   ],
   "source": [
    "epochs = 100\n",
    "for e in range(epochs):\n",
    "    epoch_loss = 0.\n",
    "    for i, (data, label) in enumerate(train_data):\n",
    "        data = data.as_in_context(ctx)\n",
    "        label = label.as_in_context(ctx)\n",
    "\n",
    "        with mx.autograd.record():\n",
    "            output = net(data)\n",
    "            l = loss(output, label)\n",
    "\n",
    "        l.backward()\n",
    "        trainer.step(data.shape[0])\n",
    "\n",
    "        epoch_loss += l.mean()\n",
    "\n",
    "    print(\"Epoch [{}] Loss: {}, LR {}\".format(e, epoch_loss.asscalar()/(i+1), trainer.learning_rate))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Testing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We get a random element from the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = random.randint(0, len(test_dataset)-1)\n",
    "\n",
    "x_orig = X[split_idx+n]\n",
    "y_orig = Y[split_idx+n]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_pred(x):\n",
    "    x, _ = transform(x, x)\n",
    "    output = net(x.as_in_context(ctx).expand_dims(axis=0))\n",
    "\n",
    "    # Convert the predicted indices back to a string\n",
    "    pred = ''.join([vocab[int(o)] for o in output[0].argmax(axis=1).asnumpy().tolist()])\n",
    "    return pred"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Printing the result:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X 611 671 275 871 944\n",
      "Predicted 275 611 671 871 944\n",
      "Label 275 611 671 871 944\n"
     ]
    }
   ],
   "source": [
    "x_ = ' '.join(map(str, x_orig))\n",
    "label = ' '.join(map(str, y_orig))\n",
    "print(\"X {}\\nPredicted {}\\nLabel {}\".format(x_, get_pred(x_orig), label))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also pick our own example, and the network manages to sort it without a problem:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10 30 130 500 999 \n"
     ]
    }
   ],
   "source": [
    "print(get_pred(np.array([500, 30, 999, 10, 130])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model has even learned to generalize to examples not in the training set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Only four numbers: 105 202 302 501 \n"
     ]
    }
   ],
   "source": [
    "print(\"Only four numbers:\", get_pred(np.array([105, 302, 501, 202])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "However, we can see that it has trouble with some edge cases:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Small digits: 8 0 42 28 \n",
      "Small digits, 6 numbers: 10 0 20 82 71 115 \n"
     ]
    }
   ],
   "source": [
    "print(\"Small digits:\", get_pred(np.array([10, 3, 5, 2, 8])))\n",
    "print(\"Small digits, 6 numbers:\", get_pred(np.array([10, 33, 52, 21, 82, 10])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This could be improved by adjusting the training dataset accordingly."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}