1# pylint: disable=invalid-name
2"""Command line options of job submission script."""
3import os
4import argparse
5
def get_cache_file_set(args):
    """Get the list of files to be cached.

    Parameters
    ----------
    args: argparse.Namespace
        The arguments returned by the parser. Reads ``auto_file_cache``,
        ``command`` and ``files``.

    Returns
    -------
    cache_file_set: set of str
        The set of files to be cached to local execution environment. It
        contains every existing token of ``args.command`` (only when
        ``auto_file_cache`` is on) and every existing entry of ``args.files``.

    command: list of str
        The command to launch. When ``auto_file_cache`` is on, every token
        that names an existing path is rewritten to ``./<basename>``. When it
        is off, this is an unmodified copy of ``args.command``.
    """
    fset = set()
    if args.auto_file_cache:
        cmds = []
        for token in args.command:
            if os.path.exists(token):
                fset.add(token)
                # normpath drops a trailing separator so that 'dir/' maps to
                # './dir' rather than './'; basename also understands the
                # platform's native separator.
                cmds.append('./' + os.path.basename(os.path.normpath(token)))
            else:
                cmds.append(token)
    else:
        # Nothing is rewritten, so return the command as-is instead of an
        # empty list that would silently drop it.
        cmds = list(args.command)

    for fname in args.files:
        if os.path.exists(fname):
            fset.add(fname)
    return fset, cmds
37
38
def get_memory_mb(mem_str):
    """Get the memory in MB from memory string.

    Parameters
    ----------
    mem_str: str
        String representation of memory requirement: a non-negative number
        followed by ``g`` (GB) or ``m`` (MB), case-insensitive. A trailing
        ``b`` (``4gb``, ``512MB``) and surrounding whitespace are accepted.

    Returns
    -------
    mem_mb: int
        Memory requirement in MB (fractional MB are truncated).

    Raises
    ------
    RuntimeError
        If the suffix is missing/unknown, the number cannot be parsed or is
        not finite, or the value is negative.
    """
    msg = 'Invalid memory specification %s, need to be a number follows g or m' % mem_str
    spec = mem_str.strip().lower()
    # Treat '4gb' / '512mb' the same as '4g' / '512m'.
    if spec.endswith(('gb', 'mb')):
        spec = spec[:-1]
    if spec.endswith('g'):
        scale = 1024
    elif spec.endswith('m'):
        scale = 1
    else:
        raise RuntimeError(msg)
    try:
        value = float(spec[:-1])
        # int() raises ValueError for nan and OverflowError for inf.
        mem_mb = int(value * scale)
    except (ValueError, OverflowError):
        # Report malformed numbers with the same exception type as a bad suffix.
        raise RuntimeError(msg)
    if value < 0:
        raise RuntimeError(msg)
    return mem_mb
58
59
def _str_to_bool(value):
    """Convert a command line string such as 'true'/'false' to a bool.

    argparse's ``type=bool`` cannot be used for this, because ``bool('False')``
    is ``True``: any non-empty string would switch the option on.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # '' stays False, matching what bool('') produced previously.
    if lowered in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


def get_opts(args=None):
    """Get options to launch the job.

    Parameters
    ----------
    args: list of str, optional
        Command line arguments to parse. When None, argparse parses
        ``sys.argv[1:]``.

    Returns
    -------
    args: argparse.Namespace
        The parsed arguments. Arguments unknown to the parser are appended to
        ``args.command``. ``args.worker_memory_mb`` and ``args.server_memory_mb``
        hold the memory requirements converted by ``get_memory_mb``.

    Raises
    ------
    RuntimeError
        If no cluster is given via ``--cluster`` or the ``DMLC_SUBMIT_CLUSTER``
        environment variable, or if ``get_memory_mb`` rejects a memory value.
    """
    parser = argparse.ArgumentParser(description='DMLC job submission.')
    parser.add_argument('--cluster', type=str,
                        choices=['yarn', 'slurm', 'mpi', 'sge', 'local', 'ssh', 'mesos', 'kubernetes'],
                        help=('Cluster type of this submission, ' +
                              'default to env variable ${DMLC_SUBMIT_CLUSTER}.'))
    parser.add_argument('--num-workers', required=True, type=int,
                        help='Number of worker processes to be launched.')
    parser.add_argument('--worker-cores', default=1, type=int,
                        help='Number of cores to be allocated for each worker process.')
    parser.add_argument('--worker-memory', default='1g', type=str,
                        help=('Memory to be allocated for each worker,' +
                              ' must end with g or m.'))
    parser.add_argument('--num-servers', default=0, type=int,
                        help='Number of server processes to be launched. Only used in PS jobs.')
    parser.add_argument('--server-cores', default=1, type=int,
                        help=('Number of cores to be allocated for each server process.' +
                              ' Only used in PS jobs.'))
    parser.add_argument('--server-memory', default='1g', type=str,
                        help=('Memory to be allocated for each server, ' +
                              'must end with g or m.'))
    parser.add_argument('--jobname', default=None, type=str, help='Name of the job.')
    parser.add_argument('--queue', default='default', type=str,
                        help='The submission queue the job should go to.')
    parser.add_argument('--log-level', default='INFO', type=str,
                        choices=['INFO', 'DEBUG'],
                        help='Logging level of the logger.')
    parser.add_argument('--log-file', default=None, type=str,
                        help=('Output log to the specific log file, ' +
                              'the log is still printed on stderr.'))
    parser.add_argument('--host-ip', default=None, type=str,
                        help=('Host IP address, this is only needed ' +
                              'if the host IP cannot be automatically guessed.'))
    parser.add_argument('--hdfs-tempdir', default='/tmp', type=str,
                        help=('Temporary directory in HDFS, ' +
                              'only needed in YARN mode.'))
    parser.add_argument('--host-file', default=None, type=str,
                        help='The file contains the list of hostnames, needed for MPI and ssh.')
    parser.add_argument('--sge-log-dir', default=None, type=str,
                        help='Log directory of SGE jobs, only needed in SGE mode.')
    # A real boolean converter: with type=bool, '--auto-file-cache False'
    # would evaluate to True.
    parser.add_argument(
        '--auto-file-cache', default=True, type=_str_to_bool,
        help=('Automatically cache files appearing in the command line' +
              ' to the local executor folder (true/false).' +
              ' This will also cause all the file names in the command to be rewritten,' +
              ' e.g. `../../kmeans ../kmeans.conf` will be rewritten to `./kmeans ./kmeans.conf`'))
    parser.add_argument('--files', default=[], action='append',
                        help=('The cached file list which will be copied to local environment.' +
                              ' You may need this option to cache additional files,' +
                              ' or when --auto-file-cache is off.'))
    parser.add_argument('--archives', default=[], action='append',
                        help=('Same as cached files,' +
                              ' but corresponds to archive files that will be unzipped locally.' +
                              ' You can use this option to ship python libraries.' +
                              ' Only valid in yarn jobs.'))
    parser.add_argument('--env', action='append', default=[],
                        help='Client and ApplicationMaster environment variables.')
    parser.add_argument('--yarn-app-classpath', type=str,
                        help=('Explicit YARN ApplicationMaster classpath.' +
                              ' Can be used to override defaults.'))
    parser.add_argument('--yarn-app-dir', type=str,
                        default=os.path.join(os.path.dirname(__file__), os.pardir, 'yarn'),
                        help='Directory to YARN appmaster. Only used in YARN mode.')
    parser.add_argument('--mesos-master', type=str,
                        help='Mesos master, default to ${MESOS_MASTER}')
    parser.add_argument('--ship-libcxx', default=None, type=str,
                        help=('The path to the customized gcc lib folder.' +
                              ' You can use this option to ship customized libstdc++' +
                              ' library to the workers.'))
    parser.add_argument('--sync-dst-dir', type=str,
                        help=('If specified, it will sync the current directory' +
                              " into remote machines' SYNC_DST_DIR."))
    parser.add_argument('command', nargs='+',
                        help='Command to be launched')
    parser.add_argument('--slurm-worker-nodes', default=None, type=int,
                        help=('Number of nodes on which workers are run. Used only in SLURM mode.' +
                              ' If not explicitly set, it defaults to number of workers.'))
    parser.add_argument('--slurm-server-nodes', default=None, type=int,
                        help=('Number of nodes on which parameter servers are run. Used only in SLURM mode.' +
                              ' If not explicitly set, it defaults to number of parameter servers.'))
    parser.add_argument('--kube-namespace', default="default", type=str,
                        help=('A namespace in which all tasks are run. Used only in Kubernetes mode.' +
                              ' If not explicitly set, it defaults to default.'))
    parser.add_argument('--kube-worker-image', default="mxnet/python", type=str,
                        help=('Container image of workers. Used only in Kubernetes mode.' +
                              ' If not explicitly set, it defaults to mxnet/python.'))
    parser.add_argument('--kube-server-image', default="mxnet/python", type=str,
                        help=('Container image of servers. Used only in Kubernetes mode.' +
                              ' If not explicitly set, it defaults to mxnet/python.'))
    parser.add_argument('--kube-worker-template', default=None, type=str,
                        help=('Manifest template for workers. Used only in Kubernetes mode.' +
                              ' Can be used to override defaults.'))
    parser.add_argument('--kube-server-template', default=None, type=str,
                        help=('Manifest template for servers. Used only in Kubernetes mode.' +
                              ' Can be used to override defaults.'))
    parser.add_argument('--local-num-attempt', default=0, type=int,
                        help='Number of attempts the local tracker can restart a slave.')
    (args, unknown) = parser.parse_known_args(args)
    # Options the parser does not know are forwarded to the launched command.
    args.command += unknown

    if args.cluster is None:
        args.cluster = os.getenv('DMLC_SUBMIT_CLUSTER', None)

    if args.cluster is None:
        raise RuntimeError('--cluster is not specified, ' +
                           'you can also specify the default behavior via ' +
                           'environment variable DMLC_SUBMIT_CLUSTER')

    args.worker_memory_mb = get_memory_mb(args.worker_memory)
    args.server_memory_mb = get_memory_mb(args.server_memory)
    return args
181