1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18## based on https://github.com/apache/mxnet/issues/1302
19## Parses the model fit log file and generates a train/val vs epoch plot
20import matplotlib.pyplot as plt
21import numpy as np
22import re
23import argparse
24
# Command-line interface: the only option is the path of the training log.
parser = argparse.ArgumentParser(description='Parses log file and generates train/val curves')
parser.add_argument('--log-file', type=str, default="log_tr_va",
                    help='the path of log file')

# Raw strings keep the regex escapes (\s, \d, \.) out of Python's own
# string-escape processing; the compiled patterns are unchanged.
# Each pattern captures the number that follows "] Train-accuracy=" or
# "] Validation-accuracy=" in a log line.
TR_RE = re.compile(r'.*?]\sTrain-accuracy=([\d\.]+)')
VA_RE = re.compile(r'.*?]\sValidation-accuracy=([\d\.]+)')


def parse_accuracies(log_text):
    """Extract the accuracy values recorded in a training log.

    Parameters
    ----------
    log_text : str
        Full contents of the log file.

    Returns
    -------
    tuple of (list of float, list of float)
        Train accuracies and validation accuracies, each in the order in
        which they appear in the log. Either list may be empty.
    """
    train = [float(value) for value in TR_RE.findall(log_text)]
    validation = [float(value) for value in VA_RE.findall(log_text)]
    return train, validation


def epoch_ticks(num_epochs, step=5):
    """Return x-axis tick positions 0, step, 2*step, ... below num_epochs.

    Unlike taking min()/max() over the epoch indices, this is well defined
    (an empty array) when num_epochs is 0.
    """
    return np.arange(0, num_epochs, step)


def plot_accuracies(log_tr, log_va):
    """Plot train and validation accuracy against the epoch index.

    Parameters
    ----------
    log_tr : list of float
        Train accuracies, one entry per logged epoch.
    log_va : list of float
        Validation accuracies, one entry per logged epoch.

    Each series is plotted against its own index range, so the two lists
    are allowed to have different lengths.
    """
    plt.figure(figsize=(8, 6))
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")

    series = (
        (log_tr, "r", "Train accuracy"),
        (log_va, "b", "Validation accuracy"),
    )
    for values, color, label in series:
        # Build the x values from this series' own length: plotting both
        # series against a shared index fails when the lengths differ.
        plt.plot(np.arange(len(values)), values, 'o', linestyle='-',
                 color=color, label=label)

    plt.legend(loc="best")
    plt.xticks(epoch_ticks(max(len(log_tr), len(log_va))))
    # linspace includes the 1.0 end point, which arange(0, 1, 0.2) drops.
    plt.yticks(np.linspace(0, 1, 6))
    plt.ylim([0, 1])
    plt.show()


def main():
    """Read the log file named on the command line and show the curves."""
    args = parser.parse_args()

    # The context manager closes the file as soon as it has been read.
    with open(args.log_file) as log_fh:
        log = log_fh.read()

    log_tr, log_va = parse_accuracies(log)
    if not log_tr and not log_va:
        # Nothing to plot: report it clearly instead of failing later
        # while computing the axis ticks.
        parser.error("no Train-accuracy/Validation-accuracy entries found in "
                     "{!r}".format(args.log_file))

    plot_accuracies(log_tr, log_va)


if __name__ == "__main__":
    main()
54