# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Parse a model-fit log file and plot train/validation accuracy vs. epoch.

Based on https://github.com/apache/mxnet/issues/1302
"""
import argparse
import re

import numpy as np

# Each pattern captures the number following "Train-accuracy=" /
# "Validation-accuracy=", which must come right after a "]" and one
# whitespace character. Raw strings keep "\s", "\d" and "\." as regex
# escapes; plain strings trigger invalid-escape warnings on newer Pythons.
TR_RE = re.compile(r'.*?]\sTrain-accuracy=([\d\.]+)')
VA_RE = re.compile(r'.*?]\sValidation-accuracy=([\d\.]+)')


def parse_log(text):
    """Extract training and validation accuracies from log text.

    Parameters
    ----------
    text : str
        Full contents of a model-fit log.

    Returns
    -------
    tuple of (list of float, list of float)
        Train accuracies and validation accuracies, each in the order
        they appear in ``text``. Either list may be empty.
    """
    log_tr = [float(x) for x in TR_RE.findall(text)]
    log_va = [float(x) for x in VA_RE.findall(text)]
    return log_tr, log_va


def plot_curves(log_tr, log_va, tick_step=5):
    """Plot train (red) and validation (blue) accuracy against epoch index.

    Parameters
    ----------
    log_tr, log_va : sequence of float
        Accuracy values. The i-th entry is plotted at x = i.
        NOTE(review): this assumes one entry per epoch starting at epoch 0;
        confirm for logs that validate only every few epochs.
    tick_step : int, optional
        Spacing between x-axis ticks (default 5).
    """
    # Imported lazily so the parsing helpers remain importable on machines
    # without matplotlib (or without a display backend).
    import matplotlib.pyplot as plt

    plt.figure(figsize=(8, 6))
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    # Each series gets its own x values: sharing the train index made
    # plot() fail whenever the validation count differed from the train count.
    plt.plot(np.arange(len(log_tr)), log_tr, 'o', linestyle='-', color="r",
             label="Train accuracy")
    plt.plot(np.arange(len(log_va)), log_va, 'o', linestyle='-', color="b",
             label="Validation accuracy")

    plt.legend(loc="best")
    n_epochs = max(len(log_tr), len(log_va))
    plt.xticks(np.arange(0, n_epochs, tick_step))
    # Upper bound slightly above 1 so the 1.0 tick is included.
    plt.yticks(np.arange(0, 1.01, 0.2))
    plt.ylim([0, 1])
    plt.show()


def main(argv=None):
    """Command-line entry point: read ``--log-file``, parse it, show the plot.

    Exits with a usage error if the log contains no accuracy entries.
    """
    parser = argparse.ArgumentParser(description='Parses log file and generates train/val curves')
    parser.add_argument('--log-file', type=str, default="log_tr_va",
                        help='the path of log file')
    args = parser.parse_args(argv)

    with open(args.log_file) as f:
        log = f.read()

    log_tr, log_va = parse_log(log)
    if not log_tr and not log_va:
        # Fail with a clear message instead of an opaque empty-array error.
        parser.error("no Train-accuracy or Validation-accuracy entries found in %s"
                     % args.log_file)
    plot_curves(log_tr, log_va)


if __name__ == "__main__":
    main()