from __future__ import annotations
import atexit
import contextlib
import datetime
import io
import logging
import os
import re
import sys
from collections.abc import Iterable, Iterator
from types import TracebackType
from python_utils import types
from python_utils.converters import scale_1024
from python_utils.terminal import get_terminal_size
from python_utils.time import epoch, format_time, timedelta_to_seconds
from progressbar import base, env, terminal
if types.TYPE_CHECKING:
from .bar import ProgressBar, ProgressBarMixinBase
# Make sure these are available for import
assert timedelta_to_seconds is not None
assert get_terminal_size is not None
assert format_time is not None
assert scale_1024 is not None
assert epoch is not None
StringT = types.TypeVar('StringT', bound=types.StringTypes)
[docs]
def deltas_to_seconds(
*deltas: None | datetime.timedelta | float,
default: types.Optional[types.Type[ValueError]] = ValueError,
) -> int | float | None:
"""
Convert timedeltas and seconds as int to seconds as float while coalescing.
>>> deltas_to_seconds(datetime.timedelta(seconds=1, milliseconds=234))
1.234
>>> deltas_to_seconds(123)
123.0
>>> deltas_to_seconds(1.234)
1.234
>>> deltas_to_seconds(None, 1.234)
1.234
>>> deltas_to_seconds(0, 1.234)
0.0
>>> deltas_to_seconds()
Traceback (most recent call last):
...
ValueError: No valid deltas passed to `deltas_to_seconds`
>>> deltas_to_seconds(None)
Traceback (most recent call last):
...
ValueError: No valid deltas passed to `deltas_to_seconds`
>>> deltas_to_seconds(default=0.0)
0.0
"""
for delta in deltas:
if delta is None:
continue
if isinstance(delta, datetime.timedelta):
return timedelta_to_seconds(delta)
elif not isinstance(delta, float):
return float(delta)
else:
return delta
if default is ValueError:
raise ValueError('No valid deltas passed to `deltas_to_seconds`')
else:
# mypy doesn't understand the `default is ValueError` check
return default # type: ignore
[docs]
def no_color(value: StringT) -> StringT:
"""
Return the `value` without ANSI escape codes.
>>> no_color(b'\u001b[1234]abc')
b'abc'
>>> str(no_color('\u001b[1234]abc'))
'abc'
>>> str(no_color('\u001b[1234]abc'))
'abc'
>>> no_color(123)
Traceback (most recent call last):
...
TypeError: `value` must be a string or bytes, got 123
"""
if isinstance(value, bytes):
pattern: bytes = bytes(terminal.ESC, 'ascii') + b'\\[.*?[@-~]'
return re.sub(pattern, b'', value) # type: ignore
elif isinstance(value, str):
return re.sub('\x1b\\[.*?[@-~]', '', value) # type: ignore
else:
raise TypeError(f'`value` must be a string or bytes, got {value!r}')
[docs]
def len_color(value: types.StringTypes) -> int:
"""
Return the length of `value` without ANSI escape codes.
>>> len_color(b'\u001b[1234]abc')
3
>>> len_color('\u001b[1234]abc')
3
>>> len_color('\u001b[1234]abc')
3
"""
return len(no_color(value))
[docs]
class WrappingIO:
buffer: io.StringIO
target: base.IO
capturing: bool
listeners: set
needs_clear: bool = False
def __init__(
self,
target: base.IO,
capturing: bool = False,
listeners: types.Optional[types.Set[ProgressBar]] = None,
) -> None:
self.buffer = io.StringIO()
self.target = target
self.capturing = capturing
self.listeners = listeners or set()
self.needs_clear = False
[docs]
def write(self, value: str) -> int:
ret = 0
if self.capturing:
ret += self.buffer.write(value)
if '\n' in value: # pragma: no branch
self.needs_clear = True
for listener in self.listeners: # pragma: no branch
listener.update()
else:
ret += self.target.write(value)
if '\n' in value: # pragma: no branch
self.flush_target()
return ret
[docs]
def flush(self) -> None:
self.buffer.flush()
def _flush(self) -> None:
if value := self.buffer.getvalue():
self.flush()
self.target.write(value)
self.buffer.seek(0)
self.buffer.truncate(0)
self.needs_clear = False
# when explicitly flushing, always flush the target as well
self.flush_target()
[docs]
def flush_target(self) -> None: # pragma: no cover
if not self.target.closed and getattr(self.target, 'flush', None):
self.target.flush()
def __enter__(self) -> WrappingIO:
return self
[docs]
def fileno(self) -> int:
return self.target.fileno()
[docs]
def isatty(self) -> bool:
return self.target.isatty()
[docs]
def read(self, n: int = -1) -> str:
return self.target.read(n)
[docs]
def readable(self) -> bool:
return self.target.readable()
[docs]
def readline(self, limit: int = -1) -> str:
return self.target.readline(limit)
[docs]
def readlines(self, hint: int = -1) -> list[str]:
return self.target.readlines(hint)
[docs]
def seek(self, offset: int, whence: int = os.SEEK_SET) -> int:
return self.target.seek(offset, whence)
[docs]
def seekable(self) -> bool:
return self.target.seekable()
[docs]
def tell(self) -> int:
return self.target.tell()
[docs]
def truncate(self, size: types.Optional[int] = None) -> int:
return self.target.truncate(size)
[docs]
def writable(self) -> bool:
return self.target.writable()
[docs]
def writelines(self, lines: Iterable[str]) -> None:
return self.target.writelines(lines)
[docs]
def close(self) -> None:
self.flush()
self.target.close()
def __next__(self) -> str:
return self.target.__next__()
def __iter__(self) -> Iterator[str]:
return self.target.__iter__()
def __exit__(
self,
__t: type[BaseException] | None,
__value: BaseException | None,
__traceback: TracebackType | None,
) -> None:
self.close()
[docs]
class StreamWrapper:
"""Wrap stdout and stderr globally."""
stdout: base.TextIO | WrappingIO
stderr: base.TextIO | WrappingIO
original_excepthook: types.Callable[
[
types.Type[BaseException],
BaseException,
TracebackType | None,
],
None,
]
wrapped_stdout: int = 0
wrapped_stderr: int = 0
wrapped_excepthook: int = 0
capturing: int = 0
listeners: set
def __init__(self) -> None:
self.stdout = self.original_stdout = sys.stdout
self.stderr = self.original_stderr = sys.stderr
self.original_excepthook = sys.excepthook
self.wrapped_stdout = 0
self.wrapped_stderr = 0
self.wrapped_excepthook = 0
self.capturing = 0
self.listeners = set()
if env.env_flag('WRAP_STDOUT', default=False): # pragma: no cover
self.wrap_stdout()
if env.env_flag('WRAP_STDERR', default=False): # pragma: no cover
self.wrap_stderr()
[docs]
def start_capturing(self, bar: ProgressBarMixinBase | None = None) -> None:
if bar: # pragma: no branch
self.listeners.add(bar)
self.capturing += 1
self.update_capturing()
[docs]
def stop_capturing(self, bar: ProgressBarMixinBase | None = None) -> None:
if bar: # pragma: no branch
with contextlib.suppress(KeyError):
self.listeners.remove(bar)
self.capturing -= 1
self.update_capturing()
[docs]
def update_capturing(self) -> None: # pragma: no cover
if isinstance(self.stdout, WrappingIO):
self.stdout.capturing = self.capturing > 0
if isinstance(self.stderr, WrappingIO):
self.stderr.capturing = self.capturing > 0
if self.capturing <= 0:
self.flush()
[docs]
def wrap(self, stdout: bool = False, stderr: bool = False) -> None:
if stdout:
self.wrap_stdout()
if stderr:
self.wrap_stderr()
[docs]
def wrap_stdout(self) -> WrappingIO:
self.wrap_excepthook()
if not self.wrapped_stdout:
self.stdout = sys.stdout = WrappingIO( # type: ignore
self.original_stdout,
listeners=self.listeners,
)
self.wrapped_stdout += 1
return sys.stdout # type: ignore
[docs]
def wrap_stderr(self) -> WrappingIO:
self.wrap_excepthook()
if not self.wrapped_stderr:
self.stderr = sys.stderr = WrappingIO( # type: ignore
self.original_stderr,
listeners=self.listeners,
)
self.wrapped_stderr += 1
return sys.stderr # type: ignore
[docs]
def unwrap_excepthook(self) -> None:
if self.wrapped_excepthook:
self.wrapped_excepthook -= 1
sys.excepthook = self.original_excepthook
[docs]
def wrap_excepthook(self) -> None:
if not self.wrapped_excepthook:
logger.debug('wrapping excepthook')
self.wrapped_excepthook += 1
sys.excepthook = self.excepthook
[docs]
def unwrap(self, stdout: bool = False, stderr: bool = False) -> None:
if stdout:
self.unwrap_stdout()
if stderr:
self.unwrap_stderr()
[docs]
def unwrap_stdout(self) -> None:
if self.wrapped_stdout > 1:
self.wrapped_stdout -= 1
else:
sys.stdout = self.original_stdout
self.wrapped_stdout = 0
[docs]
def unwrap_stderr(self) -> None:
if self.wrapped_stderr > 1:
self.wrapped_stderr -= 1
else:
sys.stderr = self.original_stderr
self.wrapped_stderr = 0
[docs]
def needs_clear(self) -> bool: # pragma: no cover
stdout_needs_clear = getattr(self.stdout, 'needs_clear', False)
stderr_needs_clear = getattr(self.stderr, 'needs_clear', False)
return stderr_needs_clear or stdout_needs_clear
[docs]
def flush(self) -> None:
if self.wrapped_stdout and isinstance(self.stdout, WrappingIO):
try:
self.stdout._flush()
except io.UnsupportedOperation: # pragma: no cover
self.wrapped_stdout = False
logger.warning(
'Disabling stdout redirection, %r is not seekable',
sys.stdout,
)
if self.wrapped_stderr and isinstance(self.stderr, WrappingIO):
try:
self.stderr._flush()
except io.UnsupportedOperation: # pragma: no cover
self.wrapped_stderr = False
logger.warning(
'Disabling stderr redirection, %r is not seekable',
sys.stderr,
)
[docs]
def excepthook(
self,
exc_type: type[BaseException],
exc_value: BaseException,
exc_traceback: types.TracebackType | None,
) -> None:
self.original_excepthook(exc_type, exc_value, exc_traceback)
self.flush()
[docs]
class AttributeDict(dict):
"""
A dict that can be accessed with .attribute.
>>> attrs = AttributeDict(spam=123)
# Reading
>>> attrs['spam']
123
>>> attrs.spam
123
# Read after update using attribute
>>> attrs.spam = 456
>>> attrs['spam']
456
>>> attrs.spam
456
# Read after update using dict access
>>> attrs['spam'] = 123
>>> attrs['spam']
123
>>> attrs.spam
123
# Read after update using dict access
>>> del attrs.spam
>>> attrs['spam']
Traceback (most recent call last):
...
KeyError: 'spam'
>>> attrs.spam
Traceback (most recent call last):
...
AttributeError: No such attribute: spam
>>> del attrs.spam
Traceback (most recent call last):
...
AttributeError: No such attribute: spam
"""
def __getattr__(self, name: str) -> int:
if name in self:
return self[name]
else:
raise AttributeError(f'No such attribute: {name}')
def __setattr__(self, name: str, value: int) -> None:
self[name] = value
def __delattr__(self, name: str) -> None:
if name in self:
del self[name]
else:
raise AttributeError(f'No such attribute: {name}')
logger: logging.Logger = logging.getLogger(__name__)
streams = StreamWrapper()
atexit.register(streams.flush)