1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18import argparse
19import logging
20import mxnet as mx
21import numpy as np
22from data import get_movielens_iter, get_movielens_data
23from model import matrix_fact_net
24import os
25
# Configure the root logger at DEBUG level as soon as the module is loaded.
logging.basicConfig(level=logging.DEBUG)

# Command-line options of the training script. ArgumentDefaultsHelpFormatter
# appends each option's default value to its --help text.
parser = argparse.ArgumentParser(
    description="Run matrix factorization with sparse embedding",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    '--num-epoch', type=int, default=3,
    help='number of epochs to train',
)
parser.add_argument(
    '--seed', type=int, default=1,
    help='random seed',
)
parser.add_argument(
    '--batch-size', type=int, default=128,
    help='number of examples per batch',
)
parser.add_argument(
    '--log-interval', type=int, default=100,
    help='logging interval',
)
parser.add_argument(
    '--factor-size', type=int, default=128,
    help="the factor size of the embedding operation",
)
# No default is given, so args.gpus is None unless the flag is passed.
parser.add_argument(
    '--gpus', type=str,
    help="list of gpus to run, e.g. 0 or 0,2. empty means using cpu().",
)
parser.add_argument(
    '--dense', action='store_true',
    help="whether to use dense embedding",
)
43
# Settings for the MovieLens dataset this script trains on: the dataset name,
# the train/validation file paths, and the `max_user` / `max_movie` limits.
MOVIELENS = dict(
    dataset='ml-10m',
    train='./data/ml-10M100K/r1.train',
    val='./data/ml-10M100K/r1.test',
    max_user=71569,
    max_movie=65135,
)
51
def batch_row_ids(data_batch):
    """Return the row ids referenced by the current mini-batch.

    The first entry of ``data_batch.data`` is read as the item ids and the
    second entry as the user ids; both are cast to int64.

    Returns a dict with the user ids under 'user_weight' and the item ids
    under 'item_weight'.
    """
    batch_data = data_batch.data
    item_ids, user_ids = batch_data[0], batch_data[1]
    row_ids = {'user_weight': user_ids.astype(np.int64)}
    row_ids['item_weight'] = item_ids.astype(np.int64)
    return row_ids
58
def all_row_ids(data_batch):
    """Return row ids covering every row of both weights.

    ``data_batch`` is not used; the id ranges run from 0 up to (but not
    including) MOVIELENS['max_user'] and MOVIELENS['max_movie'], as int64.

    Returns a dict with the user range under 'user_weight' and the movie
    range under 'item_weight'.
    """
    row_limits = (('user_weight', MOVIELENS['max_user']),
                  ('item_weight', MOVIELENS['max_movie']))
    return {weight_name: mx.nd.arange(0, limit, dtype='int64')
            for weight_name, limit in row_limits}
64
if __name__ == '__main__':
    # Script entry point: parse the command line, build the MovieLens
    # iterators and the matrix factorization module, then train for
    # `num_epoch` epochs and log the validation MSE after each one.

    # NOTE(review): logging.basicConfig() is already called at module level
    # with level=DEBUG. Per the logging documentation a later call does nothing
    # once the root logger has handlers, so this level/format is presumably
    # not applied -- confirm which configuration is intended.
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.INFO, format=head)

    # arg parser
    args = parser.parse_args()
    logging.info(args)
    num_epoch = args.num_epoch
    batch_size = args.batch_size
    optimizer = 'sgd'
    factor_size = args.factor_size
    log_interval = args.log_interval

    # one GPU context per id listed in --gpus, otherwise a single CPU context
    ctx = [mx.gpu(int(i)) for i in args.gpus.split(',')] if args.gpus else [mx.cpu()]
    learning_rate = 0.1
    # seed both MXNet and NumPy random number generators
    mx.random.seed(args.seed)
    np.random.seed(args.seed)

    # prepare dataset and iterators
    max_user = MOVIELENS['max_user']
    max_movies = MOVIELENS['max_movie']
    data_dir = os.path.join(os.getcwd(), 'data')
    get_movielens_data(data_dir, MOVIELENS['dataset'])
    train_iter = get_movielens_iter(MOVIELENS['train'], batch_size)
    val_iter = get_movielens_iter(MOVIELENS['val'], batch_size)

    # construct the model
    net = matrix_fact_net(factor_size, factor_size, max_user, max_movies, dense=args.dense)

    # initialize the module
    mod = mx.module.Module(net, context=ctx, data_names=['user', 'item'],
                           label_names=['score'])
    mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
    mod.init_params(initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
    # only the learning rate and gradient rescaling are passed to the
    # optimizer; the gradient is rescaled by 1/batch_size
    optim = mx.optimizer.create(optimizer, learning_rate=learning_rate,
                                rescale_grad=1.0/batch_size)
    mod.init_optimizer(optimizer=optim, kvstore='device')
    # use MSE as the metric
    metric = mx.metric.create(['MSE'])
    speedometer = mx.callback.Speedometer(batch_size, log_interval)
    logging.info('Training started ...')
    for epoch in range(num_epoch):
        metric.reset()
        # nbatch is the 1-based index of the batch within the epoch
        for nbatch, batch in enumerate(train_iter, start=1):
            # prepare the module with the row ids used by this mini-batch
            mod.prepare(batch, sparse_row_id_fn=batch_row_ids)
            mod.forward_backward(batch)
            # update all parameters
            mod.update()
            # update training metric
            mod.update_metric(metric, batch.label)
            speedometer_param = mx.model.BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                       eval_metric=metric, locals=locals())
            speedometer(speedometer_param)

        # prepare the module weight with all row ids for inference. Alternatively, one could call
        # score = mod.score(val_iter, ['MSE'], sparse_row_id_fn=batch_row_ids)
        # to fetch the weight per mini-batch
        mod.prepare(None, sparse_row_id_fn=all_row_ids)
        # evaluate metric on validation dataset
        score = mod.score(val_iter, ['MSE'])
        # pass the values as logging arguments so formatting is deferred
        logging.info('epoch %d, eval MSE = %s ', epoch, score[0][1])
        # reset the iterator for next pass of data
        train_iter.reset()
        val_iter.reset()
    logging.info('Training completed.')
133