import warnings

import numpy as np
import pytest
from mxnet.gluon import data

import gluonnlp as nlp
from gluonnlp.data import sampler as s

# Number of samples used by the sampler tests below.
N = 1000


def test_sorted_sampler():
    """SortedSampler must yield indices in descending order of the sort key."""
    dataset = data.SimpleDataset([np.random.normal(0, 1, (np.random.randint(10, 100), 1, 1))
                                  for _ in range(N)])
    gt_sample_id = sorted(range(len(dataset)), key=lambda i: dataset[i].shape, reverse=True)
    sample_ret = list(s.SortedSampler([ele.shape[0] for ele in dataset]))
    # zip() silently truncates to the shorter input, so check the length explicitly.
    assert len(sample_ret) == len(gt_sample_id)
    for lhs, rhs in zip(gt_sample_id, sample_ret):
        assert lhs == rhs


@pytest.mark.parametrize('seq_lengths', [[np.random.randint(10, 100) for _ in range(N)],
                                         [(np.random.randint(10, 100), np.random.randint(10, 100))
                                          for _ in range(N)]])
@pytest.mark.parametrize('ratio', [0.0, 0.5])
@pytest.mark.parametrize('shuffle', [False, True])
@pytest.mark.parametrize('num_buckets', [1, 10, 100, 5000])
@pytest.mark.parametrize('bucket_scheme', [s.ConstWidthBucket(),
                                           s.LinearWidthBucket(),
                                           s.ExpWidthBucket()])
@pytest.mark.parametrize('use_average_length', [False, True])
@pytest.mark.parametrize('num_shards', range(4))
def test_fixed_bucket_sampler(seq_lengths, ratio, shuffle, num_buckets, bucket_scheme,
                              use_average_length, num_shards):
    """Check FixedBucketSampler over a grid of bucket/shard configurations.

    With ``num_shards > 0`` every yielded item must contain ``num_shards``
    entries; with ``num_shards == 0`` the yielded batches together must
    contain exactly ``N`` distinct sample ids.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        sampler = s.FixedBucketSampler(seq_lengths, batch_size=8, num_buckets=num_buckets,
                                       ratio=ratio, shuffle=shuffle,
                                       use_average_length=use_average_length,
                                       bucket_scheme=bucket_scheme, num_shards=num_shards)

        print(sampler.stats())
        total_sampled_ids = []
        for batch_sample_ids in sampler:
            if num_shards > 0:
                assert len(batch_sample_ids) == num_shards
            else:
                total_sampled_ids.extend(batch_sample_ids)
        if num_shards == 0:
            assert len(set(total_sampled_ids)) == len(total_sampled_ids) == N


@pytest.mark.parametrize('bucket_keys', [[1, 5, 10, 100], [10, 100], [200]])
@pytest.mark.parametrize('ratio', [0.0, 0.5])
@pytest.mark.parametrize('shuffle', [False, True])
def test_fixed_bucket_sampler_with_single_key(bucket_keys, ratio, shuffle):
    """FixedBucketSampler with explicit scalar bucket keys yields each of the N ids once."""
    seq_lengths = [np.random.randint(10, 100) for _ in range(N)]
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        sampler = s.FixedBucketSampler(seq_lengths, batch_size=8, num_buckets=None,
                                       bucket_keys=bucket_keys, ratio=ratio, shuffle=shuffle)
        print(sampler.stats())
        total_sampled_ids = []
        for batch_sample_ids in sampler:
            total_sampled_ids.extend(batch_sample_ids)
        assert len(set(total_sampled_ids)) == len(total_sampled_ids) == N


# NOTE: this test used to share its name with the scalar-key test above, which
# rebound the module-level name and prevented pytest from collecting that test.
@pytest.mark.parametrize('bucket_keys', [[(1, 1), (5, 10), (10, 20), (20, 10), (100, 100)],
                                         [(20, 20), (30, 15), (100, 100)],
                                         [(100, 200)]])
@pytest.mark.parametrize('ratio', [0.0, 0.5])
@pytest.mark.parametrize('shuffle', [False, True])
def test_fixed_bucket_sampler_with_tuple_keys(bucket_keys, ratio, shuffle):
    """FixedBucketSampler with explicit 2-tuple bucket keys yields each of the N ids once."""
    seq_lengths = [(np.random.randint(10, 100), np.random.randint(10, 100)) for _ in range(N)]
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        sampler = s.FixedBucketSampler(seq_lengths, batch_size=8, num_buckets=None,
                                       bucket_keys=bucket_keys, ratio=ratio, shuffle=shuffle)
        print(sampler.stats())
        total_sampled_ids = []
        for batch_sample_ids in sampler:
            total_sampled_ids.extend(batch_sample_ids)
        assert len(set(total_sampled_ids)) == len(total_sampled_ids) == N


def test_fixed_bucket_sampler_compactness():
    """Sixteen lengths (16..31) with batch size 8 and two buckets must give two batches."""
    samples = list(
        s.FixedBucketSampler(
            np.arange(16, 32), 8, num_buckets=2,
            bucket_scheme=nlp.data.ConstWidthBucket()))
    assert len(samples) == 2


@pytest.mark.parametrize('seq_lengths', [[np.random.randint(10, 100) for _ in range(N)],
                                         [(np.random.randint(10, 100), np.random.randint(10, 100))
                                          for _ in range(N)]])
@pytest.mark.parametrize('mult', [10, 100])
@pytest.mark.parametrize('batch_size', [5, 7])
@pytest.mark.parametrize('shuffle', [False, True])
def test_sorted_bucket_sampler(seq_lengths, mult, batch_size, shuffle):
    """SortedBucketSampler batches together contain each of the N ids exactly once."""
    sampler = s.SortedBucketSampler(sort_keys=seq_lengths,
                                    batch_size=batch_size,
                                    mult=mult, shuffle=shuffle)
    total_sampled_ids = []
    for batch_sample_ids in sampler:
        total_sampled_ids.extend(batch_sample_ids)
    assert len(set(total_sampled_ids)) == len(total_sampled_ids) == N


@pytest.mark.parametrize('num_samples', [30])
@pytest.mark.parametrize('num_parts', [3, 7])
@pytest.mark.parametrize('repeat', [1, 3])
def test_split_sampler(num_samples, num_parts, repeat):
    """All SplitSampler parts together yield every index exactly ``repeat`` times."""
    total_count = 0
    indices = []
    for part_idx in range(num_parts):
        sampler = s.SplitSampler(num_samples, num_parts, part_idx, repeat=repeat)
        count = 0
        for i in sampler:
            count += 1
            indices.append(i)
        total_count += count
        # The number of yielded indices must agree with the reported length.
        assert count == len(sampler)
    assert total_count == num_samples * repeat
    assert np.allclose(sorted(indices), np.repeat(list(range(num_samples)), repeat))


@pytest.mark.parametrize('num_samples', [30])
@pytest.mark.parametrize('num_parts', [3, 7])
def test_split_sampler_even_size(num_samples, num_parts):
    """With ``even_size=True`` the total count is ceil(num_samples / num_parts) * num_parts."""
    total_count = 0
    indices = []
    for part_idx in range(num_parts):
        sampler = s.SplitSampler(num_samples, num_parts, part_idx, even_size=True)
        count = 0
        for i in sampler:
            count += 1
            indices.append(i)
        total_count += count
        # The number of yielded indices must agree with the reported length.
        assert count == len(sampler)
        print(count)
    # Ceiling division, then scaled back up: the total the parts must sum to.
    expected_count = int(num_samples + num_parts - 1) // num_parts * num_parts
    assert total_count == expected_count, (total_count, expected_count)