from __future__ import absolute_import, division

from copy import copy
from functools import partial

from .auto import tqdm as tqdm_auto

try:
    import keras
except (ImportError, AttributeError) as e:
    try:
        from tensorflow import keras
    except ImportError:
        raise e
__author__ = {"github.com/": ["casperdcl"]}
__all__ = ['TqdmCallback']

# Keys of per-batch `logs` that are stripped before the remaining entries are
# shown as the batch bar's postfix.
_BATCH_LOG_KEYS = ('batch', 'size')


def _batch_delta(logs):
    """Return how far a batch bar advances for one batch.

    Uses `logs['size']` when present, otherwise 1. `logs` may be `None`
    (the default of the hook built by `TqdmCallback.bar2callback`), in which
    case 1 is returned instead of raising.
    """
    return (logs or {}).get('size', 1)


class TqdmCallback(keras.callbacks.Callback):
    """Keras callback for epoch and batch progress.

    One bar tracks epochs. Depending on `verbose`, a second bar tracks
    batches within the current epoch.
    """
    @staticmethod
    def bar2callback(bar, pop=None, delta=(lambda logs: 1)):
        """Build a hook `callback(index, logs=None)` that advances `bar`.

        Parameters
        ----------
        bar : tqdm
            Bar to update on each call.
        pop : iterable of str, optional
            Keys removed from a shallow copy of `logs` before it is shown
            as the bar's postfix.
        delta : callable, optional
            `delta(logs) -> n`: amount to advance `bar` by on each call
            [default: always 1]. May receive `logs=None`.

        Returns
        -------
        callable
            The hook. Its first positional argument (epoch/batch index) is
            ignored.
        """
        def callback(_, logs=None):
            n = delta(logs)
            if logs:
                if pop:
                    # Work on a copy so the caller's logs dict is not mutated.
                    logs = copy(logs)
                    for key in pop:
                        logs.pop(key, 0)
                # Defer redrawing to the `update()` call below.
                bar.set_postfix(logs, refresh=False)
            bar.update(n)

        return callback

    def __init__(self, epochs=None, data_size=None, batch_size=None, verbose=1,
                 tqdm_class=tqdm_auto, **tqdm_kwargs):
        """
        Parameters
        ----------
        epochs : int, optional
            Initial total of the epoch bar. It is replaced by the epoch count
            found in Keras' training params (if any, and if different) when
            training begins.
        data_size : int, optional
            Number of training pairs.
        batch_size : int, optional
            Number of training pairs per batch.
        verbose : int
            0: epoch, 1: batch (transient), 2: batch. [default: 1].
            At the start of each epoch, the batch bar's total is taken from
            Keras' training params (`samples`, `nb_sample` or `steps`). If
            none is available, it falls back to `ceil(data_size / batch_size)`
            when both are given, and is otherwise left unknown.
        tqdm_class : optional
            `tqdm` class to use for bars [default: `tqdm.auto.tqdm`].
        tqdm_kwargs : optional
            Any other arguments used for all bars.
        """
        # Let the Keras base class initialise whatever state it defines.
        super(TqdmCallback, self).__init__()
        if tqdm_kwargs:
            tqdm_class = partial(tqdm_class, **tqdm_kwargs)
        self.tqdm_class = tqdm_class
        self.epoch_bar = tqdm_class(total=epochs, unit='epoch')
        self.on_epoch_end = self.bar2callback(self.epoch_bar)
        if data_size and batch_size:
            # Ceiling division: a final partial batch still counts as a batch.
            self.batches = batches = (data_size + batch_size - 1) // batch_size
        else:
            self.batches = batches = None
        self.verbose = verbose
        if verbose == 1:
            # A single transient bar, reset at the start of every epoch.
            self.batch_bar = tqdm_class(total=batches, unit='batch', leave=False)
            self.on_batch_end = self.bar2callback(
                self.batch_bar, pop=_BATCH_LOG_KEYS, delta=_batch_delta)

    def on_train_begin(self, *_, **__):
        """Adopt the epoch count from Keras' params if it differs from ours."""
        params = self.params.get
        # `nb_epoch` is a fallback key (presumably the legacy Keras name).
        auto_total = params('epochs', params('nb_epoch', None))
        if auto_total is not None and auto_total != self.epoch_bar.total:
            self.epoch_bar.reset(total=auto_total)

    def on_epoch_begin(self, epoch, *_, **__):
        """Sync the epoch bar with `epoch` and (re)initialise the batch bar.

        Raises
        ------
        KeyError
            If `verbose` is truthy but neither 1 nor 2.
        """
        ebar = self.epoch_bar
        if ebar.n < epoch:
            # Training started past the bar's position (presumably resuming
            # from a later epoch): jump the bar forward without counting.
            ebar.n = ebar.last_print_n = ebar.initial = epoch
        if not self.verbose:
            return
        params = self.params.get
        total = params('samples', params(
            'nb_sample', params('steps', None))) or self.batches
        # The bar may count samples; scale so it displays batches.
        unit_scale = 1 / (params('batch_size', 1) or 1)
        if self.verbose == 2:
            # A fresh persistent bar per epoch; close the previous one first.
            if hasattr(self, 'batch_bar'):
                self.batch_bar.close()
            self.batch_bar = self.tqdm_class(
                total=total, unit='batch', leave=True, unit_scale=unit_scale)
            self.on_batch_end = self.bar2callback(
                self.batch_bar, pop=_BATCH_LOG_KEYS, delta=_batch_delta)
        elif self.verbose == 1:
            self.batch_bar.unit_scale = unit_scale
            self.batch_bar.reset(total=total)
        else:
            raise KeyError('Unknown verbosity')

    def on_train_end(self, *_, **__):
        """Close all bars.

        With `verbose=2` the batch bar only exists once an epoch has begun,
        so it is looked up defensively.
        """
        batch_bar = getattr(self, 'batch_bar', None)
        if batch_bar is not None:
            batch_bar.close()
        self.epoch_bar.close()

    def display(self):
        """Displays in the current cell in Notebooks."""
        container = getattr(self.epoch_bar, 'container', None)
        if container is None:
            # Bars were not created by a notebook-capable tqdm class.
            return
        from .notebook import display
        display(container)
        batch_bar = getattr(self, 'batch_bar', None)
        if batch_bar is not None:
            display(batch_bar.container)

    # The following overrides report that batch-level hooks are implemented;
    # presumably Keras queries them to decide whether to invoke batch hooks.
    @staticmethod
    def _implements_train_batch_hooks():
        return True

    @staticmethod
    def _implements_test_batch_hooks():
        return True

    @staticmethod
    def _implements_predict_batch_hooks():
        return True