1#!/usr/bin/env python
2
3import os, sys, math, random
4from collections import defaultdict
5
# Python 3 has no xrange; alias it to range (which is lazy there) so the
# rest of the script can use xrange on both major versions.
if sys.version_info >= (3,):
	xrange = range
8
def exit_with_help(argv):
	"""Print the command-line usage text to stdout and exit with status 1.

	argv -- the program's argument list; argv[0] is shown as the program name.

	Raises SystemExit(1); this function never returns.
	"""
	print("""\
Usage: {0} [options] dataset subset_size [output1] [output2]

This script randomly selects a subset of the dataset.

options:
-s method : method of selection (default 0)
     0 -- stratified selection (classification only)
     1 -- random selection

output1 : the subset (optional)
output2 : rest of the data (optional)
If output1 is omitted, the subset will be printed on the screen.""".format(argv[0]))
	# Use sys.exit rather than the builtin exit(): the latter is injected by
	# the site module for interactive sessions and is not guaranteed to exist
	# (e.g. "python -S" or frozen executables).
	sys.exit(1)
24
def process_options(argv):
	"""Parse the command line.

	argv -- full argument list, laid out as
	        [prog, options..., dataset, subset_size, [output1], [output2]]

	Returns the tuple (dataset, subset_size, method, subset_file, rest_file):
	    dataset     -- path of the input file (str)
	    subset_size -- requested number of lines (non-negative int)
	    method      -- 0 (stratified, the default) or 1 (random)
	    subset_file -- writable file for the subset; sys.stdout if output1
	                   was not given
	    rest_file   -- writable file for the remaining lines, or None if
	                   output2 was not given

	On malformed arguments a message and the usage text are printed and the
	process exits through exit_with_help(). Unrecognised options are ignored
	(with a warning on stderr).
	"""
	argc = len(argv)
	if argc < 3:
		exit_with_help(argv)

	# default method is stratified selection
	method = 0
	subset_file = sys.stdout
	rest_file = None

	i = 1
	while i < argc:
		# startswith() also copes with an empty-string argument,
		# where argv[i][0] would raise IndexError
		if not argv[i].startswith("-"):
			break
		if argv[i] == "-s":
			i = i + 1
			try:
				method = int(argv[i])
			except (IndexError, ValueError):
				# "-s" was the last argument or its value is not a number
				print("Option -s requires an integer argument")
				exit_with_help(argv)
			if method not in [0,1]:
				print("Unknown selection method {0}".format(method))
				exit_with_help(argv)
		else:
			# still ignored for backward compatibility, but no longer silently
			sys.stderr.write("Warning: ignoring unknown option {0}\n".format(argv[i]))
		i = i + 1

	# dataset and subset_size must both remain after the options
	if i+1 >= argc:
		exit_with_help(argv)

	dataset = argv[i]
	try:
		subset_size = int(argv[i+1])
	except ValueError:
		print("subset_size must be an integer, got {0}".format(argv[i+1]))
		exit_with_help(argv)
	if subset_size < 0:
		print("subset_size must not be negative, got {0}".format(subset_size))
		exit_with_help(argv)

	if i+2 < argc:
		subset_file = open(argv[i+2],'w')
	if i+3 < argc:
		try:
			rest_file = open(argv[i+3],'w')
		except IOError:
			# do not leak the subset file that was just opened
			if subset_file is not sys.stdout:
				subset_file.close()
			raise

	return dataset, subset_size, method, subset_file, rest_file
55
def random_selection(dataset, subset_size):
	"""Choose subset_size distinct line numbers of dataset at random.

	dataset     -- path of the input file
	subset_size -- number of lines to pick

	Returns the chosen 0-based line numbers as a sorted list.
	Raises ValueError (from random.sample) if subset_size is negative or
	larger than the number of lines in the file.
	"""
	# count the lines; the with-block closes the file handle right away
	with open(dataset,'r') as f:
		l = sum(1 for line in f)
	# range is used instead of xrange so this function does not rely on the
	# module-level compatibility alias; it works on Python 2 and 3 alike
	return sorted(random.sample(range(l), subset_size))
59
def stratified_selection(dataset, subset_size):
	"""Choose line numbers of dataset so that each class keeps its share.

	The label of a line is its first whitespace-delimited token. Every label
	receives ceil(class_size * subset_size / total) lines, but at least one,
	capped by what is still left of subset_size. Blank lines carry no label;
	they are not counted and are never selected.

	dataset     -- path of the input file
	subset_size -- total number of lines to pick

	Returns the chosen 0-based line numbers as a sorted list.
	Exits the process with status -1 (after a message on stderr) when
	subset_size is used up before every class received an instance.
	Raises ValueError (from random.sample) if a class is asked for more
	lines than it contains, e.g. when subset_size exceeds the file size.
	"""
	# group the 0-based line numbers by label
	label_linenums = defaultdict(list)
	l = 0
	with open(dataset,'r') as f:
		for i, line in enumerate(f):
			fields = line.split(None,1)
			if not fields:
				# blank line: nothing to stratify on, skip instead of crashing
				continue
			label_linenums[fields[0]].append(i)
			l += 1

	remaining = subset_size
	ret = []

	# classes with fewer data are sampled first; otherwise
	# some rare classes may not be selected
	for label in sorted(label_linenums, key=lambda x: len(label_linenums[x])):
		linenums = label_linenums[label]
		label_size = len(linenums)
		# proportional share rounded up, in exact integer arithmetic: the
		# float form can overshoot (100*(14/200.) is 7.000000000000001 and
		# would be rounded up to 8)
		share = (label_size*subset_size + l - 1) // l
		# at least one instance per class
		s = min(remaining, max(1, share))
		if s == 0:
			sys.stderr.write('''\
Error: failed to have at least one instance per class
    1. You may have regression data.
    2. Your classification data is unbalanced or too small.
Please use -s 1.
''')
			sys.exit(-1)
		remaining -= s
		ret += random.sample(linenums, s)
	return sorted(ret)
88
def main(argv=sys.argv):
	"""Select a subset of the dataset's lines and write it out.

	argv -- command line as accepted by process_options()

	The selected lines are written to the subset file (or stdout) in their
	original order. If a rest file was requested, every non-selected line is
	written to it, also in original order. Files opened for this run are
	closed even if selection or copying fails; sys.stdout is only flushed,
	never closed.
	"""
	dataset, subset_size, method, subset_file, rest_file = process_options(argv)
	#uncomment the following line to fix the random seed
	#random.seed(0)
	selected_lines = []

	try:
		if method == 0:
			selected_lines = stratified_selection(dataset, subset_size)
		elif method == 1:
			selected_lines = random_selection(dataset, subset_size)

		#select instances based on selected_lines
		with open(dataset,'r') as data_file:
			next_linenum = 0
			for selected in selected_lines:
				# lines in front of the next selected one belong to the rest
				while next_linenum < selected:
					line = data_file.readline()
					if rest_file:
						rest_file.write(line)
					next_linenum += 1
				subset_file.write(data_file.readline())
				next_linenum += 1

			# everything after the last selected line is also rest
			if rest_file:
				for line in data_file:
					rest_file.write(line)
	finally:
		if subset_file is sys.stdout:
			# stdout is not ours to close; later prints must keep working
			subset_file.flush()
		else:
			subset_file.close()
		if rest_file:
			rest_file.close()
117
# Run the tool only when this file is executed as a script, not on import.
if __name__ == "__main__":
	main(argv=sys.argv)
120
121