1<?php
2
3declare(strict_types=1);
4
5namespace Phpml\Dataset;
6
7use Phpml\Exception\InvalidArgumentException;
8
9/**
10 * MNIST dataset: http://yann.lecun.com/exdb/mnist/
11 * original mnist dataset reader: https://github.com/AndrewCarterUK/mnist-neural-network-plain-php
12 */
final class MnistDataset extends ArrayDataset
{
    private const MAGIC_IMAGE = 0x00000803;

    private const MAGIC_LABEL = 0x00000801;

    private const IMAGE_ROWS = 28;

    private const IMAGE_COLS = 28;

    /**
     * Loads MNIST images as samples and MNIST labels as targets.
     *
     * @param string $imagePath path to an IDX image file (magic 0x00000803, 28x28 images)
     * @param string $labelPath path to an IDX label file (magic 0x00000801)
     *
     * @throws InvalidArgumentException when a file cannot be opened, has an unexpected
     *                                  magic number or image size, is truncated, or when
     *                                  the number of images and labels differ
     */
    public function __construct(string $imagePath, string $labelPath)
    {
        $this->samples = $this->readImages($imagePath);
        $this->targets = $this->readLabels($labelPath);

        if (count($this->samples) !== count($this->targets)) {
            throw new InvalidArgumentException('Must have the same number of images and labels');
        }
    }

    /**
     * Reads every image in the file, each as a flat list of pixel values scaled by 1/255.
     *
     * @throws InvalidArgumentException
     */
    private function readImages(string $imagePath): array
    {
        $stream = fopen($imagePath, 'rb');

        if ($stream === false) {
            throw new InvalidArgumentException('Could not open file: '.$imagePath);
        }

        $images = [];

        try {
            // Header: four big-endian unsigned 32-bit integers.
            $fields = unpack('Nmagic/Nsize/Nrows/Ncols', $this->readBytes($stream, 16, $imagePath));

            if ($fields['magic'] !== self::MAGIC_IMAGE) {
                throw new InvalidArgumentException('Invalid magic number: '.$imagePath);
            }

            if ($fields['rows'] !== self::IMAGE_ROWS) {
                throw new InvalidArgumentException('Invalid number of image rows: '.$imagePath);
            }

            if ($fields['cols'] !== self::IMAGE_COLS) {
                throw new InvalidArgumentException('Invalid number of image cols: '.$imagePath);
            }

            // Rows/cols were validated above, so every image has this many bytes.
            $imageSize = self::IMAGE_ROWS * self::IMAGE_COLS;

            for ($i = 0; $i < $fields['size']; ++$i) {
                $imageBytes = $this->readBytes($stream, $imageSize, $imagePath);

                // Scale each unsigned byte into the range [0, 1].
                $images[] = array_map(static function ($b) {
                    return $b / 255;
                }, array_values(unpack('C*', $imageBytes)));
            }
        } finally {
            fclose($stream);
        }

        return $images;
    }

    /**
     * Reads every label in the file as a list of unsigned byte values.
     *
     * @throws InvalidArgumentException
     */
    private function readLabels(string $labelPath): array
    {
        $stream = fopen($labelPath, 'rb');

        if ($stream === false) {
            throw new InvalidArgumentException('Could not open file: '.$labelPath);
        }

        try {
            // Header: two big-endian unsigned 32-bit integers.
            $fields = unpack('Nmagic/Nsize', $this->readBytes($stream, 8, $labelPath));

            if ($fields['magic'] !== self::MAGIC_LABEL) {
                throw new InvalidArgumentException('Invalid magic number: '.$labelPath);
            }

            $labels = $this->readBytes($stream, $fields['size'], $labelPath);
        } finally {
            fclose($stream);
        }

        // An empty label section is valid; avoid unpacking an empty string.
        if ($labels === '') {
            return [];
        }

        return array_values(unpack('C*', $labels));
    }

    /**
     * Reads exactly $length bytes from $stream.
     *
     * A single fread() may return fewer bytes than requested on non-plain-file
     * streams, so this keeps reading until the buffer is full or the stream ends.
     * A $length of 0 returns '' without calling fread() (which rejects 0).
     *
     * @param resource $stream
     *
     * @throws InvalidArgumentException when the stream ends before $length bytes are read
     */
    private function readBytes($stream, int $length, string $path): string
    {
        $buffer = '';

        while (strlen($buffer) < $length && !feof($stream)) {
            $chunk = fread($stream, $length - strlen($buffer));

            // false signals an error; '' means no progress is possible right now.
            if ($chunk === false || $chunk === '') {
                break;
            }

            $buffer .= $chunk;
        }

        if (strlen($buffer) !== $length) {
            throw new InvalidArgumentException('Unexpected end of file: '.$path);
        }

        return $buffer;
    }
}
102