# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""export helper functions"""
# coding: utf-8
import os
import logging
import mxnet as mx


def load_module(sym_filepath, params_filepath):
    """Loads the MXNet model files and
    returns the MXNet symbol and params (weights).

    The symbol and the parameters are read directly from the two paths
    given. The files therefore do not have to follow the
    ``prefix-symbol.json`` / ``prefix-epoch.params`` checkpoint naming
    convention, and the files that are loaded are exactly the files that
    were validated.

    Parameters
    ----------
    sym_filepath : str
        Path to the symbol json file
    params_filepath : str
        Path to the params file

    Returns
    -------
    sym : MXNet symbol
        Model symbol object

    params : dict
        Model weights including both arg and aux params. Keys are
        parameter names with any ``arg:`` / ``aux:`` tag removed. On a
        name clash the aux param wins.

    Raises
    ------
    ValueError
        If either path is not an existing file, or if the params file
        holds an unnamed list of arrays instead of named parameters.
    """
    if not (os.path.isfile(sym_filepath) and os.path.isfile(params_filepath)):
        raise ValueError("Symbol and params files provided are invalid")

    # Load the validated files themselves. Re-deriving a checkpoint
    # prefix/epoch from the file names (as mx.model.load_checkpoint needs)
    # breaks on directories containing '-' or '.', on non-numeric suffixes,
    # and on epochs not zero-padded to 4 digits (a different file is opened).
    sym = mx.sym.load(sym_filepath)
    saved = mx.nd.load(params_filepath)
    if not isinstance(saved, dict):
        raise ValueError("Params file %s must contain named parameters, "
                         "not an unnamed list of arrays" % params_filepath)

    arg_params = {}
    aux_params = {}
    for key, value in saved.items():
        tag, sep, name = key.partition(':')
        if sep and tag == 'arg':
            arg_params[name] = value
        elif sep and tag == 'aux':
            aux_params[name] = value
        else:
            # Parameters saved without an arg:/aux: tag are kept under
            # their stored name instead of being dropped.
            arg_params[key] = value

    # Merging arg and aux parameters (aux applied last, so it wins on clash)
    params = {}
    params.update(arg_params)
    params.update(aux_params)

    return sym, params