import warnings

import numpy as np
import pytest
from mxnet.gluon import data

import gluonnlp as nlp
from gluonnlp.data import sampler as s

N = 1000
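
# SortedSampler should return indices ordered by sequence length, longest
# first; compare against a reference sort over the same keys.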
def test_sorted_sampler():
    dataset = data.SimpleDataset([np.random.normal(0, 1, (np.random.randint(10, 100), 1, 1))
                                  for _ in range(N)])
    gt_sample_id = sorted(range(len(dataset)), key=lambda i: dataset[i].shape[0], reverse=True)
    sample_ret = list(s.SortedSampler([ele.shape[0] for ele in dataset]))
    for lhs, rhs in zip(gt_sample_id, sample_ret):
        assert lhs == rhs

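# FixedBucketSampler should visit every index exactly once regardless of the
# bucket scheme; when num_shards > 0 each yielded element holds one batch per
# shard rather than a flat list of indices.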
@pytest.mark.parametrize('seq_lengths', [[np.random.randint(10, 100) for _ in range(N)],
                                         [(np.random.randint(10, 100), np.random.randint(10, 100))
                                          for _ in range(N)]])
@pytest.mark.parametrize('ratio', [0.0, 0.5])
@pytest.mark.parametrize('shuffle', [False, True])
@pytest.mark.parametrize('num_buckets', [1, 10, 100, 5000])
@pytest.mark.parametrize('bucket_scheme', [s.ConstWidthBucket(),
                                           s.LinearWidthBucket(),
                                           s.ExpWidthBucket()])
@pytest.mark.parametrize('use_average_length', [False, True])
@pytest.mark.parametrize('num_shards', range(4))
def test_fixed_bucket_sampler(seq_lengths, ratio, shuffle, num_buckets, bucket_scheme,
                              use_average_length, num_shards):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        sampler = s.FixedBucketSampler(seq_lengths, batch_size=8, num_buckets=num_buckets,
                                       ratio=ratio, shuffle=shuffle,
                                       use_average_length=use_average_length,
                                       bucket_scheme=bucket_scheme, num_shards=num_shards)

    print(sampler.stats())
    total_sampled_ids = []
    for batch_sample_ids in sampler:
        if num_shards > 0:
            assert len(batch_sample_ids) == num_shards
        else:
            total_sampled_ids.extend(batch_sample_ids)
    if num_shards == 0:
        assert len(set(total_sampled_ids)) == len(total_sampled_ids) == N

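# With num_buckets=None and explicit scalar bucket_keys, the sampler should
# still cover all N indices exactly once.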
@pytest.mark.parametrize('bucket_keys', [[1, 5, 10, 100], [10, 100], [200]])
@pytest.mark.parametrize('ratio', [0.0, 0.5])
@pytest.mark.parametrize('shuffle', [False, True])
def test_fixed_bucket_sampler_with_single_key(bucket_keys, ratio, shuffle):
    seq_lengths = [np.random.randint(10, 100) for _ in range(N)]
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        sampler = s.FixedBucketSampler(seq_lengths, batch_size=8, num_buckets=None,
                                       bucket_keys=bucket_keys, ratio=ratio, shuffle=shuffle)
    print(sampler.stats())
    total_sampled_ids = []
    for batch_sample_ids in sampler:
        total_sampled_ids.extend(batch_sample_ids)
    assert len(set(total_sampled_ids)) == len(total_sampled_ids) == N

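# Same coverage check with tuple (source, target) bucket keys and paired
# sequence lengths.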
@pytest.mark.parametrize('bucket_keys', [[(1, 1), (5, 10), (10, 20), (20, 10), (100, 100)],
                                         [(20, 20), (30, 15), (100, 100)],
                                         [(100, 200)]])
@pytest.mark.parametrize('ratio', [0.0, 0.5])
@pytest.mark.parametrize('shuffle', [False, True])
def test_fixed_bucket_sampler_with_tuple_key(bucket_keys, ratio, shuffle):
    seq_lengths = [(np.random.randint(10, 100), np.random.randint(10, 100)) for _ in range(N)]
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        sampler = s.FixedBucketSampler(seq_lengths, batch_size=8, num_buckets=None,
                                       bucket_keys=bucket_keys, ratio=ratio, shuffle=shuffle)
    print(sampler.stats())
    total_sampled_ids = []
    for batch_sample_ids in sampler:
        total_sampled_ids.extend(batch_sample_ids)
    assert len(set(total_sampled_ids)) == len(total_sampled_ids) == N


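# Sixteen lengths (16..31) split into two ConstWidthBucket buckets with
# batch_size 8 should yield exactly two batches.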
def test_fixed_bucket_sampler_compactness():
    samples = list(
        s.FixedBucketSampler(
            np.arange(16, 32), 8, num_buckets=2,
            bucket_scheme=nlp.data.ConstWidthBucket()))
    assert len(samples) == 2


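# SortedBucketSampler should also produce a permutation of all N indices for
# scalar and tuple sort keys alike.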
@pytest.mark.parametrize('seq_lengths', [[np.random.randint(10, 100) for _ in range(N)],
                                         [(np.random.randint(10, 100), np.random.randint(10, 100))
                                          for _ in range(N)]])
@pytest.mark.parametrize('mult', [10, 100])
@pytest.mark.parametrize('batch_size', [5, 7])
@pytest.mark.parametrize('shuffle', [False, True])
def test_sorted_bucket_sampler(seq_lengths, mult, batch_size, shuffle):
    sampler = s.SortedBucketSampler(sort_keys=seq_lengths,
                                    batch_size=batch_size,
                                    mult=mult, shuffle=shuffle)
    total_sampled_ids = []
    for batch_sample_ids in sampler:
        total_sampled_ids.extend(batch_sample_ids)
    assert len(set(total_sampled_ids)) == len(total_sampled_ids) == N


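# SplitSampler partitions num_samples indices over num_parts parts; across
# all parts each index must appear exactly `repeat` times.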
@pytest.mark.parametrize('num_samples', [30])
@pytest.mark.parametrize('num_parts', [3, 7])
@pytest.mark.parametrize('repeat', [1, 3])
def test_split_sampler(num_samples, num_parts, repeat):
    total_count = 0
    indices = []
    for part_idx in range(num_parts):
        sampler = s.SplitSampler(num_samples, num_parts, part_idx, repeat=repeat)
        count = 0
        for i in sampler:
            count += 1
            indices.append(i)
        total_count += count
        assert count == len(sampler)
    assert total_count == num_samples * repeat
    assert np.allclose(sorted(indices), np.repeat(list(range(num_samples)), repeat))


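# With even_size=True every part reports the same length, so the combined
# count is num_samples rounded up to a multiple of num_parts.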
@pytest.mark.parametrize('num_samples', [30])
@pytest.mark.parametrize('num_parts', [3, 7])
def test_split_sampler_even_size(num_samples, num_parts):
    total_count = 0
    indices = []
    for part_idx in range(num_parts):
        sampler = s.SplitSampler(num_samples, num_parts, part_idx, even_size=True)
        count = 0
        for i in sampler:
            count += 1
            indices.append(i)
        total_count += count
        assert count == len(sampler)
        print(count)
    expected_count = (num_samples + num_parts - 1) // num_parts * num_parts
    assert total_count == expected_count, (total_count, expected_count)