ai-content-maker/.venv/Lib/site-packages/proglog/proglog.py

392 lines
13 KiB
Python

"""Implements the generic progress logger class, and the ProgressBar class.
"""
from tqdm import tqdm, tqdm_notebook
from collections import OrderedDict
import time
SETTINGS = {
'notebook': False
}
def notebook(turn='on'):
SETTINGS['notebook'] = True if (turn == 'on') else False
def troncate_string(s, max_length=25):
return s if (len(s) < max_length) else (s[:max_length] + "...")
class ProgressLogger:
"""Generic class for progress loggers.
A progress logger contains a "state" dictionnary.
Parameters
----------
init_state
Dictionnary representing the initial state.
"""
def __init__(self, init_state=None):
self.state = {}
self.stored = {}
self.logs = []
self.log_indent = 0
if init_state is not None:
self.state.update(init_state)
def log(self, message):
self.logs.append((' ' * self.log_indent) + message)
def dump_logs(self, filepath=None):
if filepath is not None:
with open(filepath, 'a') as f:
f.write("\n".join(self.logs))
else:
return "\n".join(self.logs)
def callback(self, **kw):
"""Execute something after the state has been updated by the given
state elements.
This default callback does nothing, overwrite it by subclassing
"""
pass
def store(self, **kw):
"""Store objects in the logger and trigger ``self.store_callback``.
This works exactly like ``logger()``, but the later is meant for simple
data objects (text, numbers) that will be sent over the network or
written to a file. The ``store`` method expects rather large objects
which are not necessarily serializable, and will be used eg to draw
plots on the fly.
"""
self.stored.update(kw)
self.store_callback(**kw)
def store_callback(self, **kw):
"""Execute something after the store has been updated by the given
state elements.
This default callback does nothing, overwrite it by subclassing
"""
pass
def iter(self, **kw):
"""Iterate through a list while updating the state.
Examples
--------
>>> for username in logger.iter(user=['tom', 'tim', 'lea']:
>>> # At every loop, logger.state['user'] is updated
>>> print (username)
"""
for field, iterable in kw.items():
for it in iterable:
self(**{field: it})
yield it
def __call__(self, **kw):
self.state.update(kw)
self.callback(**kw)
class ProgressBarLogger(ProgressLogger):
"""Generic class for progress loggers.
A progress logger contains a "state" dictionnary
Parameters
----------
init_state
Initial state of the logger
bars
Either None (will be initialized with no bar) or a list/tuple of bar
names (``['main', 'sub']``) which will be initialized with index -1 and
no total, or a dictionary (possibly ordered) of bars, of the form
``{bar_1: {title: 'bar1', index: 2, total:23}, bar_2: {...}}``
ignored_bars
Either None (newly met bars will be added) or a list of blacklisted bar
names, or ``'all_others'`` to signify that all bar names not already in
``self.bars`` will be ignored.
"""
bar_indent = 2
def __init__(self, init_state=None, bars=None, ignored_bars=None,
logged_bars='all', min_time_interval=0, ignore_bars_under=0):
ProgressLogger.__init__(self, init_state)
if bars is None:
bars = OrderedDict()
elif isinstance(bars, (list, tuple)):
bars = OrderedDict([
(b, dict(title=b, index=-1, total=None, message=None,
indent=0))
for b in bars
])
if isinstance(ignored_bars, (list, tuple)):
ignored_bars = set(ignored_bars)
self.ignored_bars = ignored_bars
self.logged_bars = logged_bars
self.state['bars'] = bars
self.min_time_interval = min_time_interval
self.ignore_bars_under = ignore_bars_under
@property
def bars(self):
"""Return ``self.state['bars'].``"""
return self.state['bars']
def bar_is_ignored(self, bar):
if self.ignored_bars is None:
return False
elif self.ignored_bars == 'all_others':
return (bar not in self.bars)
else:
return bar in self.ignored_bars
def bar_is_logged(self, bar):
if (not self.logged_bars):
return False
elif self.logged_bars == 'all':
return True
else:
return bar in self.logged_bars
def iterable_is_too_short(self, iterable):
length = len(iterable) if hasattr(iterable, '__len__') else None
return (length is not None) and (length < self.ignore_bars_under)
def iter_bar(self, bar_prefix='', **kw):
"""Iterate through a list while updating a state bar.
Examples
--------
>>> for username in logger.iter_bar(user=['tom', 'tim', 'lea']):
>>> # At every loop, logger.state['bars']['user'] is updated
>>> # to {index: i, total: 3, title:'user'}
>>> print (username)
"""
if 'bar_message' in kw:
bar_message = kw.pop('bar_message')
else:
bar_message = None
bar, iterable = kw.popitem()
if self.bar_is_ignored(bar) or self.iterable_is_too_short(iterable):
return iterable
bar = bar_prefix + bar
if hasattr(iterable, '__len__'):
self(**{bar + '__total': len(iterable)})
def new_iterable():
last_time = time.time()
i = 0 # necessary in case the iterator is empty
for i, it in enumerate(iterable):
now_time = time.time()
if (i == 0) or (now_time - last_time > self.min_time_interval):
if bar_message is not None:
self(**{bar + '__message': bar_message(it)})
self(**{bar + '__index': i})
last_time = now_time
yield it
if self.bars[bar]['index'] != i:
self(**{bar + '__index': i})
self(**{bar + '__index': i + 1})
return new_iterable()
def bars_callback(self, bar, attr, value, old_value=None):
"""Execute a custom action after the progress bars are updated.
Parameters
----------
bar
Name/ID of the bar to be modified.
attr
Attribute of the bar attribute to be modified
value
New value of the attribute
old_value
Previous value of this bar's attribute.
This default callback does nothing, overwrite it by subclassing.
"""
pass
def __call__(self, **kw):
items = sorted(kw.items(), key=lambda kv: not kv[0].endswith('total'))
for key, value in items:
if '__' in key:
bar, attr = key.split('__')
if self.bar_is_ignored(bar):
continue
kw.pop(key)
if bar not in self.bars:
self.bars[bar] = dict(title=bar, index=-1,
total=None, message=None)
old_value = self.bars[bar][attr]
if self.bar_is_logged(bar):
new_bar = (attr == 'index') and (value < old_value)
if (attr == 'total') or (new_bar):
self.bars[bar]['indent'] = self.log_indent
else:
self.log_indent = self.bars[bar]['indent']
self.log("[%s] %s: %s" % (bar, attr, value))
self.log_indent += self.bar_indent
self.bars[bar][attr] = value
self.bars_callback(bar, attr, value, old_value)
self.state.update(kw)
self.callback(**kw)
class TqdmProgressBarLogger(ProgressBarLogger):
"""Tqdm-powered progress bar for console or Notebooks.
Parameters
----------
init_state
Initial state of the logger
bars
Either None (will be initialized with no bar) or a list/tuple of bar
names (``['main', 'sub']``) which will be initialized with index -1 and
no total, or a dictionary (possibly ordered) of bars, of the form
``{bar_1: {title: 'bar1', index: 2, total:23}, bar_2: {...}}``
ignored_bars
Either None (newly met bars will be added) or a list of blacklisted bar
names, or ``'all_others'`` to signify that all bar names not already in
``self.bars`` will be ignored.
leave_bars
notebook
True will make the bars look nice (HTML) in the jupyter notebook. It is
advised to leave to 'default' as the default can be globally set from
inside a notebook with ``import proglog; proglog.notebook_mode()``.
print_messages
If True, every ``logger(message='something')`` will print a message in
the console / notebook
"""
def __init__(self, init_state=None, bars=None, leave_bars=False,
ignored_bars=None, logged_bars='all', notebook='default',
print_messages=True, min_time_interval=0,
ignore_bars_under=0):
ProgressBarLogger.__init__(self, init_state=init_state, bars=bars,
ignored_bars=ignored_bars,
logged_bars=logged_bars,
ignore_bars_under=ignore_bars_under,
min_time_interval=min_time_interval)
self.leave_bars = leave_bars
self.tqdm_bars = OrderedDict([
(bar, None)
for bar in self.bars
])
if notebook == 'default':
notebook = SETTINGS['notebook']
self.notebook = notebook
self.print_messages = print_messages
self.tqdm = (tqdm_notebook if self.notebook else tqdm)
def new_tqdm_bar(self, bar):
"""Create a new tqdm bar, possibly replacing an existing one."""
if (bar in self.tqdm_bars) and (self.tqdm_bars[bar] is not None):
self.close_tqdm_bar(bar)
infos = self.bars[bar]
self.tqdm_bars[bar] = self.tqdm(
total=infos['total'],
desc=infos['title'],
postfix=dict(now=troncate_string(str(infos['message']))),
leave=self.leave_bars
)
def close_tqdm_bar(self, bar):
"""Close and erase the tqdm bar"""
self.tqdm_bars[bar].close()
if not self.notebook:
self.tqdm_bars[bar] = None
def bars_callback(self, bar, attr, value, old_value):
if (bar not in self.tqdm_bars) or (self.tqdm_bars[bar] is None):
self.new_tqdm_bar(bar)
if attr == 'index':
if value >= old_value:
total = self.bars[bar]['total']
if total and (value >= total):
self.close_tqdm_bar(bar)
else:
self.tqdm_bars[bar].update(value - old_value)
else:
self.new_tqdm_bar(bar)
self.tqdm_bars[bar].update(value + 1)
elif attr == 'message':
self.tqdm_bars[bar].set_postfix(now=troncate_string(str(value)))
self.tqdm_bars[bar].update(0)
def callback(self, **kw):
if self.print_messages and ('message' in kw) and kw['message']:
if self.notebook:
print(kw['message'])
else:
self.tqdm.write(kw['message'])
class RqWorkerProgressLogger:
def __init__(self, job):
self.job = job
if 'progress_data' not in self.job.meta:
self.job.meta['progress_data'] = {}
self.job.save()
def callback(self, **kw):
self.job.meta['progress_data'] = self.state
self.job.save()
class RqWorkerBarLogger(RqWorkerProgressLogger, ProgressBarLogger):
def __init__(self, job, init_state=None, bars=None, ignored_bars=(),
logged_bars='all', min_time_interval=0):
RqWorkerProgressLogger.__init__(self, job)
ProgressBarLogger.__init__(self, init_state=init_state, bars=bars,
ignored_bars=ignored_bars,
logged_bars=logged_bars,
min_time_interval=min_time_interval)
class MuteProgressBarLogger(ProgressBarLogger):
def bar_is_ignored(self, bar):
return True
def default_bar_logger(logger, bars=None, ignored_bars=None, logged_bars='all',
min_time_interval=0, ignore_bars_under=0):
if logger == 'bar':
return TqdmProgressBarLogger(
bars=bars,
ignored_bars=ignored_bars,
logged_bars=logged_bars,
min_time_interval=min_time_interval,
ignore_bars_under=ignore_bars_under
)
elif logger is None:
return MuteProgressBarLogger()
else:
return logger