#!/usr/bin/env python
"""Select a subset of a line-oriented dataset file.

Each line of the dataset is one instance.  The subset is chosen either
uniformly at random or stratified by label, where the label is the first
whitespace-separated token of the line.  Selected lines go to the subset
output; the remaining lines can optionally be written to a second output.
"""

import os
import sys
import math
import random
from collections import defaultdict

# Python 3 has no xrange; alias it so the same code runs on 2 and 3.
if sys.version_info[0] >= 3:
    xrange = range


def exit_with_help(argv):
    """Print the usage message (using argv[0] as program name) and exit.

    Always raises SystemExit with status 1.
    """
    print("""\
Usage: {0} [options] dataset subset_size [output1] [output2]

This script randomly selects a subset of the dataset.

options:
-s method : method of selection (default 0)
     0 -- stratified selection (classification only)
     1 -- random selection

output1 : the subset (optional)
output2 : rest of the data (optional)
If output1 is omitted, the subset will be printed on the screen.""".format(argv[0]))
    # sys.exit rather than the bare exit() builtin: the latter is injected
    # by the site module and is missing under `python -S` / frozen builds.
    sys.exit(1)


def process_options(argv):
    """Parse the command line.

    Args:
        argv: full argument vector, program name at index 0.

    Returns:
        A tuple (dataset, subset_size, method, subset_file, rest_file):
        the dataset path, the requested subset size, the selection method
        (0 = stratified, 1 = random), a writable file object for the subset
        (sys.stdout when output1 is omitted) and a writable file object for
        the remaining lines (None when output2 is omitted).

    Raises:
        SystemExit: via exit_with_help() when required arguments are
            missing or the selection method is unknown.
        ValueError: when the method or subset size is not an integer.
    """
    argc = len(argv)
    if argc < 3:
        exit_with_help(argv)

    # default method is stratified selection
    method = 0
    subset_file = sys.stdout
    rest_file = None

    i = 1
    # startswith() is also safe for an empty-string argument,
    # where argv[i][0] would raise IndexError.
    while i < argc and argv[i].startswith("-"):
        if argv[i] == "-s":
            i += 1
            if i >= argc:
                # "-s" given without a value
                exit_with_help(argv)
            method = int(argv[i])
            if method not in (0, 1):
                print("Unknown selection method {0}".format(method))
                exit_with_help(argv)
        i += 1

    # Both positional arguments (dataset, subset_size) must still be present
    # after the options have been consumed.
    if i + 1 >= argc:
        exit_with_help(argv)

    dataset = argv[i]
    subset_size = int(argv[i + 1])
    if i + 2 < argc:
        subset_file = open(argv[i + 2], 'w')
    if i + 3 < argc:
        rest_file = open(argv[i + 3], 'w')

    return dataset, subset_size, method, subset_file, rest_file


def random_selection(dataset, subset_size):
    """Pick subset_size distinct line numbers uniformly at random.

    Args:
        dataset: path of the dataset file.
        subset_size: number of lines to select.

    Returns:
        Sorted list of 0-based line numbers.

    Raises:
        ValueError: from random.sample when subset_size is negative or
            larger than the number of lines in the file.
    """
    with open(dataset, 'r') as data:
        line_count = sum(1 for _ in data)
    return sorted(random.sample(xrange(line_count), subset_size))


def stratified_selection(dataset, subset_size):
    """Pick line numbers so that every label is represented.

    The label of a line is its first whitespace-separated token.  Each label
    receives a share of the subset proportional to its frequency, rounded up
    and never less than one line, capped by what is left of subset_size.

    Args:
        dataset: path of the dataset file.
        subset_size: total number of lines to select.

    Returns:
        Sorted list of 0-based line numbers.

    Raises:
        SystemExit: with status -1 when the budget runs out before every
            label has received at least one line.
        IndexError: when the file contains a blank line (it has no label).
    """
    with open(dataset, 'r') as data:
        labels = [line.split(None, 1)[0] for line in data]

    label_linenums = defaultdict(list)
    for linenum, label in enumerate(labels):
        label_linenums[label].append(linenum)

    l = len(labels)
    remaining = subset_size
    ret = []

    # classes with fewer data are sampled first; otherwise
    # some rare classes may not be selected
    for label in sorted(label_linenums, key=lambda x: len(label_linenums[x])):
        linenums = label_linenums[label]
        label_size = len(linenums)
        # at least one instance per class
        s = int(min(remaining, max(1, math.ceil(label_size*(float(subset_size)/l)))))
        if s == 0:
            sys.stderr.write('''\
Error: failed to have at least one instance per class
    1. You may have regression data.
    2. Your classification data is unbalanced or too small.
Please use -s 1.
''')
            sys.exit(-1)
        remaining -= s
        ret.extend(random.sample(linenums, s))
    return sorted(ret)


def _write_split(dataset, selected_lines, subset_file, rest_file):
    """Copy the selected lines of dataset to subset_file, the rest to rest_file.

    selected_lines must be sorted, distinct 0-based line numbers.  rest_file
    may be None, in which case unselected lines are dropped.
    """
    pending = iter(selected_lines)
    next_selected = next(pending, None)
    with open(dataset, 'r') as data:
        for linenum, line in enumerate(data):
            if linenum == next_selected:
                subset_file.write(line)
                next_selected = next(pending, None)
            elif rest_file:
                rest_file.write(line)
            elif next_selected is None:
                # nothing left to select and nobody wants the remainder
                break


def main(argv=sys.argv):
    """Run the subset selection described by the command line in argv."""
    dataset, subset_size, method, subset_file, rest_file = process_options(argv)
    #uncomment the following line to fix the random seed
    #random.seed(0)
    selected_lines = []

    try:
        if method == 0:
            selected_lines = stratified_selection(dataset, subset_size)
        elif method == 1:
            selected_lines = random_selection(dataset, subset_size)

        #select instances based on selected_lines
        _write_split(dataset, selected_lines, subset_file, rest_file)
    finally:
        # Output handles are released even if selection fails or exits.
        # sys.stdout is only flushed, never closed: it belongs to the
        # interpreter and the caller may still want to print to it.
        if subset_file is sys.stdout:
            subset_file.flush()
        else:
            subset_file.close()
        if rest_file:
            rest_file.close()


if __name__ == '__main__':
    main(sys.argv)