1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""export helper functions"""
18# coding: utf-8
19import os
20import logging
21import mxnet as mx
22
23
def load_module(sym_filepath, params_filepath):
    """Loads the MXNet model file and
    returns MXNet symbol and params (weights).

    The files are expected to follow the checkpoint naming convention
    ``<prefix>-symbol.json`` and ``<prefix>-<epoch>.params``. The prefix is
    taken from the symbol file name and the epoch from the params file name.
    Both are then passed to ``mx.model.load_checkpoint``, so which params file
    is actually read is decided by that call. ``params_filepath`` itself only
    supplies the epoch number.

    Parameters
    ----------
    sym_filepath : str
        Path to the symbol (json) file
    params_filepath : str
        Path to the params file

    Returns
    -------
    sym : MXNet symbol
        Model symbol object

    params : dict
        Model weights including both arg and aux params.

    Raises
    ------
    ValueError
        If either path is not an existing file, or if the epoch part of the
        params file name is not an integer.
    """
    if not (os.path.isfile(sym_filepath) and os.path.isfile(params_filepath)):
        raise ValueError("Symbol and params files provided are invalid")

    # Parse only the file names (not the directories), so a '-' or '.' in a
    # directory name is never mistaken for the prefix/epoch separator or the
    # file extension.
    sym_dir, sym_name = os.path.split(sym_filepath)
    sym_stem = os.path.splitext(sym_name)[0]
    # "<prefix>-symbol" -> "<prefix>", re-attached to the symbol's directory
    model_name = os.path.join(sym_dir, sym_stem.rsplit('-', 1)[0])

    params_stem = os.path.splitext(os.path.basename(params_filepath))[0]
    _, sep, epoch_str = params_stem.rpartition('-')
    try:
        # Setting num_epochs to 0 if no "-<epoch>" suffix is present in filename
        num_epochs = int(epoch_str) if sep else 0
    except ValueError:
        # int() is the only operation here that can fail; point the caller at
        # the expected naming scheme instead of a bare int() parse error.
        raise ValueError("Model and params name should be in format: "
                         "prefix-symbol.json, prefix-epoch.params; got %r"
                         % params_filepath)

    sym, arg_params, aux_params = mx.model.load_checkpoint(model_name, num_epochs)

    # Merge arg and aux parameters into a single dict; aux entries are applied
    # last, so they win on a name clash.
    params = dict(arg_params)
    params.update(aux_params)

    return sym, params
66