import contextlib
import dataclasses
import math
import textwrap
from typing import Any, Dict, Optional

import torch
from torch import inf


@dataclasses.dataclass
class __PrinterOptions:
    """Mutable container for the module-wide print settings."""

    # Digits after the decimal point for floating point output.
    precision: int = 4
    # Element count above which a tensor is summarized instead of fully printed.
    threshold: float = 1000
    # Number of leading/trailing items kept per dimension when summarizing.
    edgeitems: int = 3
    # Target number of characters per output line.
    linewidth: int = 80
    # True/False forces scientific notation on/off; None lets _Formatter decide.
    sci_mode: Optional[bool] = None


PRINT_OPTS = __PrinterOptions()


# We could use **kwargs, but this will give better docs
def set_printoptions(
    precision=None,
    threshold=None,
    edgeitems=None,
    linewidth=None,
    profile=None,
    sci_mode=None,
):
    r"""Set options for printing. Items shamelessly taken from NumPy

    Args:
        precision: Number of digits of precision for floating point output
            (default = 4).
        threshold: Total number of array elements which trigger summarization
            rather than full `repr` (default = 1000).
        edgeitems: Number of array items in summary at beginning and end of
            each dimension (default = 3).
        linewidth: The number of characters per line for the purpose of
            inserting line breaks (default = 80). Thresholded matrices will
            ignore this parameter.
        profile: Sane defaults for pretty printing. Can override with any of
            the above options. (any one of `default`, `short`, `full`)
            Any other value is silently ignored.
        sci_mode: Enable (True) or disable (False) scientific notation. If
            None (default) is specified, the value is defined by
            `torch._tensor_str._Formatter`. This value is automatically chosen
            by the framework.

    .. note::
        Unlike the other options, ``sci_mode`` is assigned unconditionally:
        every call that does not pass ``sci_mode`` resets it to ``None``.

    Example::

        >>> # Limit the precision of elements
        >>> torch.set_printoptions(precision=2)
        >>> torch.tensor([1.12345])
        tensor([1.12])
        >>> # Limit the number of elements shown
        >>> torch.set_printoptions(threshold=5)
        >>> torch.arange(10)
        tensor([0, 1, 2,  ..., 7, 8, 9])
        >>> # Restore defaults
        >>> torch.set_printoptions(profile='default')
        >>> torch.tensor([1.12345])
        tensor([1.1235])
        >>> torch.arange(10)
        tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

    """
    # A profile is applied first so that explicit keyword arguments win.
    if profile is not None:
        if profile == "default":
            PRINT_OPTS.precision = 4
            PRINT_OPTS.threshold = 1000
            PRINT_OPTS.edgeitems = 3
            PRINT_OPTS.linewidth = 80
        elif profile == "short":
            PRINT_OPTS.precision = 2
            PRINT_OPTS.threshold = 1000
            PRINT_OPTS.edgeitems = 2
            PRINT_OPTS.linewidth = 80
        elif profile == "full":
            PRINT_OPTS.precision = 4
            PRINT_OPTS.threshold = inf
            PRINT_OPTS.edgeitems = 3
            PRINT_OPTS.linewidth = 80

    if precision is not None:
        PRINT_OPTS.precision = precision
    if threshold is not None:
        PRINT_OPTS.threshold = threshold
    if edgeitems is not None:
        PRINT_OPTS.edgeitems = edgeitems
    if linewidth is not None:
        PRINT_OPTS.linewidth = linewidth
    # Deliberately not guarded: None is a meaningful value ("auto").
    PRINT_OPTS.sci_mode = sci_mode


def get_printoptions() -> Dict[str, Any]:
    r"""Gets the current options for printing, as a dictionary that
    can be passed as ``**kwargs`` to set_printoptions().
    """
    return dataclasses.asdict(PRINT_OPTS)


@contextlib.contextmanager
def printoptions(**kwargs):
    r"""Context manager that temporarily changes the print options.  Accepted
    arguments are same as :func:`set_printoptions`."""
    old_kwargs = get_printoptions()
    set_printoptions(**kwargs)
    try:
        yield
    finally:
        # Restore the snapshot even if the body raised.
        set_printoptions(**old_kwargs)


def tensor_totype(t):
    """Cast ``t`` to ``torch.float`` if it is an MPS tensor, else to ``torch.double``."""
    dtype = torch.float if t.is_mps else torch.double
    return t.to(dtype=dtype)


class _Formatter:
    """Chooses a notation and a common column width for a tensor's elements.

    The constructor scans the tensor once and records:

    * ``floating_dtype``: whether the dtype is a floating point dtype,
    * ``int_mode``: for floating tensors, whether every nonzero finite value
      is integral (printed as ``"3."``),
    * ``sci_mode``: whether scientific notation is used,
    * ``max_width``: the widest formatted element seen, used for padding.
    """

    def __init__(self, tensor):
        self.floating_dtype = tensor.dtype.is_floating_point
        self.int_mode = True
        self.sci_mode = False
        self.max_width = 1

        with torch.no_grad():
            tensor_view = tensor.reshape(-1)

        if not self.floating_dtype:
            for value in tensor_view:
                value_str = f"{value}"
                self.max_width = max(self.max_width, len(value_str))

        else:
            nonzero_finite_vals = torch.masked_select(
                tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0)
            )

            if nonzero_finite_vals.numel() == 0:
                # no valid number, do nothing
                # (this also skips the PRINT_OPTS.sci_mode override below)
                return

            # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
            nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs())
            nonzero_finite_min = tensor_totype(nonzero_finite_abs.min())
            nonzero_finite_max = tensor_totype(nonzero_finite_abs.max())

            for value in nonzero_finite_vals:
                if value != torch.ceil(value):
                    self.int_mode = False
                    break

            if self.int_mode:
                # in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites
                # to indicate that the tensor is of floating type. add 1 to the len to account for this.
                if (
                    nonzero_finite_max / nonzero_finite_min > 1000.0
                    or nonzero_finite_max > 1.0e8
                ):
                    self.sci_mode = True
                    for value in nonzero_finite_vals:
                        value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value)
                        self.max_width = max(self.max_width, len(value_str))
                else:
                    for value in nonzero_finite_vals:
                        value_str = f"{value:.0f}"
                        self.max_width = max(self.max_width, len(value_str) + 1)
            else:
                # Check if scientific representation should be used.
                if (
                    nonzero_finite_max / nonzero_finite_min > 1000.0
                    or nonzero_finite_max > 1.0e8
                    or nonzero_finite_min < 1.0e-4
                ):
                    self.sci_mode = True
                    for value in nonzero_finite_vals:
                        value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value)
                        self.max_width = max(self.max_width, len(value_str))
                else:
                    for value in nonzero_finite_vals:
                        value_str = f"{{:.{PRINT_OPTS.precision}f}}".format(value)
                        self.max_width = max(self.max_width, len(value_str))

        # An explicit user setting overrides the automatic choice.
        if PRINT_OPTS.sci_mode is not None:
            self.sci_mode = PRINT_OPTS.sci_mode

    def width(self):
        """Return the column width every formatted element is padded to."""
        return self.max_width

    def format(self, value):
        """Format one element and left-pad it with spaces to ``max_width``."""
        if self.floating_dtype:
            if self.sci_mode:
                ret = f"{{:{self.max_width}.{PRINT_OPTS.precision}e}}".format(value)
            elif self.int_mode:
                ret = f"{value:.0f}"
                # Finite integral floats get a trailing "." ("3."); inf/nan do not.
                if not (math.isinf(value) or math.isnan(value)):
                    ret += "."
            else:
                ret = f"{{:.{PRINT_OPTS.precision}f}}".format(value)
        else:
            ret = f"{value}"
        return (self.max_width - len(ret)) * " " + ret


def _scalar_str(self, formatter1, formatter2=None):
    """Format a 0-dim tensor.

    When ``formatter2`` is given the tensor is treated as complex:
    ``formatter1`` formats ``self.real`` and ``formatter2`` formats
    ``self.imag``.
    """
    if formatter2 is not None:
        real_str = _scalar_str(self.real, formatter1)
        imag_str = (_scalar_str(self.imag, formatter2) + "j").lstrip()
        # handles negative numbers, +0.0, -0.0
        if imag_str[0] == "+" or imag_str[0] == "-":
            return real_str + imag_str
        else:
            return real_str + "+" + imag_str
    else:
        return formatter1.format(self.item())


def _vector_str(self, indent, summarize, formatter1, formatter2=None):
    """Format a 1-dim tensor as a bracketed, comma separated, wrapped list.

    ``indent`` is the column the opening bracket sits at; continuation lines
    are indented by ``indent + 1`` spaces. When ``summarize`` is true only the
    first and last ``PRINT_OPTS.edgeitems`` elements are shown.
    """
    # length includes spaces and comma between elements
    element_length = formatter1.width() + 2
    if formatter2 is not None:
        # width for imag_formatter + an extra j for complex
        element_length += formatter2.width() + 1

    elements_per_line = max(
        1, int(math.floor((PRINT_OPTS.linewidth - indent) / (element_length)))
    )

    def _val_formatter(val, formatter1=formatter1, formatter2=formatter2):
        if formatter2 is not None:
            real_str = formatter1.format(val.real)
            imag_str = (formatter2.format(val.imag) + "j").lstrip()
            # handles negative numbers, +0.0, -0.0
            if imag_str[0] == "+" or imag_str[0] == "-":
                return real_str + imag_str
            else:
                return real_str + "+" + imag_str
        else:
            return formatter1.format(val)

    if summarize and not PRINT_OPTS.edgeitems:
        # Deal with edge case that negative zero is zero
        # (self[-0:] would select the whole tensor, not an empty tail)
        data = ["..."]
    elif summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
        data = (
            [_val_formatter(val) for val in self[: PRINT_OPTS.edgeitems].tolist()]
            + [" ..."]
            + [_val_formatter(val) for val in self[-PRINT_OPTS.edgeitems :].tolist()]
        )
    else:
        data = [_val_formatter(val) for val in self.tolist()]

    data_lines = [
        data[i : i + elements_per_line] for i in range(0, len(data), elements_per_line)
    ]
    lines = [", ".join(line) for line in data_lines]
    return "[" + ("," + "\n" + " " * (indent + 1)).join(lines) + "]"


# formatter2 is only used for printing complex tensors.
# For complex tensors, formatter1 and formatter2 are the formatters for tensor.real
# and tensor.imag respectively
def _tensor_str_with_formatter(self, indent, summarize, formatter1, formatter2=None):
    """Recursively format a tensor of any rank using prepared formatters.

    Rank 0 and rank 1 are delegated to ``_scalar_str`` / ``_vector_str``;
    higher ranks format each slice along dim 0 and join them with
    ``dim - 1`` newlines.
    """
    dim = self.dim()

    if dim == 0:
        return _scalar_str(self, formatter1, formatter2)

    if dim == 1:
        return _vector_str(self, indent, summarize, formatter1, formatter2)

    if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
        slices = (
            [
                _tensor_str_with_formatter(
                    self[i], indent + 1, summarize, formatter1, formatter2
                )
                for i in range(0, PRINT_OPTS.edgeitems)
            ]
            + ["..."]
            + [
                _tensor_str_with_formatter(
                    self[i], indent + 1, summarize, formatter1, formatter2
                )
                for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))
            ]
        )
    else:
        slices = [
            _tensor_str_with_formatter(
                self[i], indent + 1, summarize, formatter1, formatter2
            )
            for i in range(0, self.size(0))
        ]

    tensor_str = ("," + "\n" * (dim - 1) + " " * (indent + 1)).join(slices)
    return "[" + tensor_str + "]"


def _tensor_str(self, indent):
    """Return the bracketed data portion of a tensor's repr.

    Normalizes the tensor (drops names, resolves neg/conj bits, upcasts
    low-precision dtypes), decides whether to summarize based on
    ``PRINT_OPTS.threshold``, builds the formatter(s) and formats the data.
    """
    if self.numel() == 0:
        return "[]"

    if self.has_names():
        # There are two main codepaths (possibly more) that tensor printing goes through:
        # - tensor data can fit comfortably on screen
        # - tensor data needs to be summarized
        # Some of the codepaths don't fully support named tensors, so we send in
        # an unnamed tensor to the formatting code as a workaround.
        self = self.rename(None)

    summarize = self.numel() > PRINT_OPTS.threshold

    if self._is_zerotensor():
        self = self.clone()

    # handle the negative bit
    if self.is_neg():
        self = self.resolve_neg()

    if self.dtype in [
        torch.float16,
        torch.bfloat16,
        torch.float8_e5m2,
        torch.float8_e5m2fnuz,
        torch.float8_e4m3fn,
        torch.float8_e4m3fnuz,
    ]:
        self = self.float()

    if self.dtype is torch.complex32:
        self = self.cfloat()

    if self.dtype.is_complex:
        # handle the conjugate bit
        self = self.resolve_conj()
        real_formatter = _Formatter(
            get_summarized_data(self.real) if summarize else self.real
        )
        imag_formatter = _Formatter(
            get_summarized_data(self.imag) if summarize else self.imag
        )
        return _tensor_str_with_formatter(
            self, indent, summarize, real_formatter, imag_formatter
        )
    else:
        formatter = _Formatter(get_summarized_data(self) if summarize else self)
        return _tensor_str_with_formatter(self, indent, summarize, formatter)


def _add_suffixes(tensor_str, suffixes, indent, force_newline):
    """Append ``suffixes`` and the closing ``")"`` to ``tensor_str``.

    A suffix is placed on a new line (indented by ``indent`` spaces) when it
    would push the current line past ``PRINT_OPTS.linewidth``, or when
    ``force_newline`` is set; ``force_newline`` only affects the first suffix.
    """
    tensor_strs = [tensor_str]
    # NOTE(review): this evaluates to (actual last-line length + 2), e.g. with
    # no newline rfind() is -1 and the result is len(tensor_str) + 2. Kept
    # as-is because changing it would move the wrap point of existing reprs;
    # confirm whether the extra slack is intentional.
    last_line_len = len(tensor_str) - tensor_str.rfind("\n") + 1
    for suffix in suffixes:
        suffix_len = len(suffix)
        if force_newline or last_line_len + suffix_len + 2 > PRINT_OPTS.linewidth:
            tensor_strs.append(",\n" + " " * indent + suffix)
            last_line_len = indent + suffix_len
            force_newline = False
        else:
            tensor_strs.append(", " + suffix)
            last_line_len += suffix_len + 2
    tensor_strs.append(")")
    return "".join(tensor_strs)


def get_summarized_data(self):
    """Return only the elements that a summarized print of ``self`` displays.

    For every dimension longer than ``2 * PRINT_OPTS.edgeitems`` the first and
    last ``edgeitems`` entries are kept. A 0-dim tensor is returned unchanged.
    When ``edgeitems`` is 0 nothing is displayed, so an empty tensor with the
    same number of dimensions is returned.
    """
    dim = self.dim()
    if dim == 0:
        return self
    if not PRINT_OPTS.edgeitems:
        # Must be checked before the 1-D branch: with edgeitems == 0 the tail
        # slice self[-0:] is the *whole* tensor, so the "summary" would be the
        # full data and the formatter would scan every element.
        return self.new_empty([0] * dim)
    if dim == 1:
        if self.size(0) > 2 * PRINT_OPTS.edgeitems:
            return torch.cat(
                (self[: PRINT_OPTS.edgeitems], self[-PRINT_OPTS.edgeitems :])
            )
        else:
            return self
    if self.size(0) > 2 * PRINT_OPTS.edgeitems:
        start = [self[i] for i in range(0, PRINT_OPTS.edgeitems)]
        end = [self[i] for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))]
        return torch.stack([get_summarized_data(x) for x in (start + end)])
    else:
        return torch.stack([get_summarized_data(x) for x in self])


def _sparse_component_str(component, indent, is_meta):
    """Render one constituent tensor (indices / values) of a sparse tensor.

    Data-less tensors (``is_meta``) are shown as ``...``. The size is appended
    whenever it cannot be read off the printed data, i.e. for data-less or
    empty components.
    """
    component_str = "..." if is_meta else _tensor_str(component, indent)
    if is_meta or component.numel() == 0:
        component_str += ", size=" + str(tuple(component.shape))
    return component_str


def _str_intern(inp, *, tensor_contents=None):
    """Build the full repr string of ``inp``: prefix, data and suffixes.

    Args:
        inp: the tensor to print.
        tensor_contents: optional pre-rendered data string; when given it is
            used in place of the data that would otherwise be formatted
            (suffixes are still computed from ``inp``).
    """
    if torch._C._functorch.is_functorch_wrapped_tensor(inp):
        return _functorch_wrapper_str_intern(inp, tensor_contents=tensor_contents)
    is_plain_tensor = type(inp) is torch.Tensor or type(inp) is torch.nn.Parameter
    if inp.is_nested:
        prefix = "nested_tensor("
    elif is_plain_tensor:
        prefix = "tensor("
    else:
        prefix = f"{type(inp).__name__}("
    indent = len(prefix)
    suffixes = []
    custom_contents_provided = tensor_contents is not None
    if custom_contents_provided:
        tensor_str = tensor_contents

    # This is used to extract the primal value and thus disable the forward AD
    # within this function.
    # TODO(albanD) This needs to be updated when more than one level is supported
    self, tangent = torch.autograd.forward_ad.unpack_dual(inp)

    # Note [Print tensor device]:
    # A general logic here is we only print device when it doesn't match
    # the device specified in default tensor type.
    # Currently torch.set_default_tensor_type() only supports CPU/CUDA, thus
    # torch._C._get_default_device() only returns either cpu or cuda.
    # In other cases, we don't have a way to set them as default yet,
    # and we should always print out device for them.
    if (
        self.device.type != torch._C._get_default_device()
        or (
            self.device.type == "cuda"
            and torch.cuda.current_device() != self.device.index
        )
        or (self.device.type == "mps")
    ):
        suffixes.append("device='" + str(self.device) + "'")

    # Tensor printing performs tensor operations like slice, indexing, etc to make it in a
    # representable format. These operations on ipu/xla/lazy/mtia tensor results in compilations. Hence,
    # to avoid compilations, copying the tensor to cpu before printing.
    if self.device.type in ["xla", "lazy", "ipu", "mtia"]:
        self = self.to("cpu")

    # TODO: add an API to map real -> complex dtypes
    _default_complex_dtype = (
        torch.cdouble if torch.get_default_dtype() == torch.double else torch.cfloat
    )
    has_default_dtype = self.dtype in (
        torch.get_default_dtype(),
        _default_complex_dtype,
        torch.int64,
        torch.bool,
    )
    if self.is_sparse:
        suffixes.append("size=" + str(tuple(self.shape)))
        from torch._subclasses.fake_tensor import FakeTensor

        is_meta = self.is_meta or isinstance(self, FakeTensor)
        if not is_meta:
            suffixes.append("nnz=" + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        if not custom_contents_provided:
            indices_prefix = "indices=tensor("
            indices = self._indices().detach()
            indices_str = _sparse_component_str(
                indices, indent + len(indices_prefix), is_meta
            )
            values_prefix = "values=tensor("
            values = self._values().detach()
            values_str = _sparse_component_str(
                values, indent + len(values_prefix), is_meta
            )
            tensor_str = (
                indices_prefix
                + indices_str
                + "),\n"
                + " " * indent
                + values_prefix
                + values_str
                + ")"
            )
    elif self.layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }:
        from torch._subclasses.fake_tensor import FakeTensor

        suffixes.append("size=" + str(tuple(self.shape)))
        is_meta = self.is_meta or isinstance(self, FakeTensor)
        if not is_meta:
            suffixes.append("nnz=" + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        if not custom_contents_provided:
            compressed_indices_method, plain_indices_method = {
                torch.sparse_csr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
                torch.sparse_csc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
                torch.sparse_bsr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
                torch.sparse_bsc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
            }[self.layout]
            if self.layout in {torch.sparse_csr, torch.sparse_bsr}:
                cdimname, pdimname = "row", "column"
            else:
                cdimname, pdimname = "column", "row"
            compressed_indices_prefix = f"c{cdimname[:3]}_indices=tensor("
            compressed_indices = compressed_indices_method(self).detach()
            compressed_indices_str = _sparse_component_str(
                compressed_indices, indent + len(compressed_indices_prefix), is_meta
            )
            plain_indices_prefix = f"{pdimname[:3]}_indices=tensor("
            plain_indices = plain_indices_method(self).detach()
            plain_indices_str = _sparse_component_str(
                plain_indices, indent + len(plain_indices_prefix), is_meta
            )
            values_prefix = "values=tensor("
            values = self.values().detach()
            values_str = _sparse_component_str(
                values, indent + len(values_prefix), is_meta
            )
            tensor_str = (
                compressed_indices_prefix
                + compressed_indices_str
                + "),\n"
                + " " * indent
                + plain_indices_prefix
                + plain_indices_str
                + "),\n"
                + " " * indent
                + values_prefix
                + values_str
                + ")"
            )
    elif self.is_quantized:
        suffixes.append("size=" + str(tuple(self.shape)))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        suffixes.append("quantization_scheme=" + str(self.qscheme()))
        if (
            self.qscheme() == torch.per_tensor_affine
            or self.qscheme() == torch.per_tensor_symmetric
        ):
            suffixes.append("scale=" + str(self.q_scale()))
            suffixes.append("zero_point=" + str(self.q_zero_point()))
        elif (
            self.qscheme() == torch.per_channel_affine
            or self.qscheme() == torch.per_channel_symmetric
            or self.qscheme() == torch.per_channel_affine_float_qparams
        ):
            suffixes.append("scale=" + str(self.q_per_channel_scales()))
            suffixes.append("zero_point=" + str(self.q_per_channel_zero_points()))
            suffixes.append("axis=" + str(self.q_per_channel_axis()))
        if not custom_contents_provided:
            tensor_str = _tensor_str(self.dequantize(), indent)
    elif self.is_nested:
        if not custom_contents_provided:

            # NOTE(review): `indent` is accepted but unused; every line gets
            # the fixed literal prefix below.
            def indented_str(s, indent):
                return "\n".join(f" {line}" for line in s.split("\n"))

            strs = ",\n".join(
                indented_str(str(t), indent + 1)
                for t in torch.ops.aten.unbind.int(self, 0)
            )
            tensor_str = f"[\n{strs}\n]"
    elif torch._is_functional_tensor(self):
        prefix = "_to_functional_tensor("
        tensor_str = repr(torch._from_functional_tensor(self))
    else:
        # Circular import problem, so we import it here
        from torch._subclasses.fake_tensor import FakeTensor

        if self.is_meta or isinstance(self, FakeTensor):
            suffixes.append("size=" + str(tuple(self.shape)))
            if self.dtype != torch.get_default_dtype():
                suffixes.append("dtype=" + str(self.dtype))
            # TODO: This implies that ellipses is valid syntax for allocating
            # a meta tensor or FakeTensor, which it could be, but it isn't right now
            if not custom_contents_provided:
                tensor_str = "..."
        else:
            if self.numel() == 0 and not self.is_sparse:
                # Explicitly print the shape if it is not (0,), to match NumPy behavior
                if self.dim() != 1:
                    suffixes.append("size=" + str(tuple(self.shape)))

                # In an empty tensor, there are no elements to infer if the dtype
                # should be int64, so it must be shown explicitly.
                if self.dtype != torch.get_default_dtype():
                    suffixes.append("dtype=" + str(self.dtype))
                if not custom_contents_provided:
                    tensor_str = "[]"
            else:
                if not PRINT_OPTS.edgeitems:
                    suffixes.append("size=" + str(tuple(self.shape)))

                if not has_default_dtype:
                    suffixes.append("dtype=" + str(self.dtype))

                if not custom_contents_provided:
                    if self.layout != torch.strided:
                        tensor_str = _tensor_str(self.to_dense(), indent)
                    else:
                        tensor_str = _tensor_str(self, indent)

    if self.layout != torch.strided:
        suffixes.append("layout=" + str(self.layout))

    # Use inp here to get the original grad_fn and not the one generated by the forward grad
    # unpacking.
    grad_fn_name = None
    try:
        grad_fn = inp.grad_fn
    except RuntimeError:
        # Accessing the grad_fn calls rebasing logic which would cause an error
        # if that tensor is a view created in no-grad mode modified in-place in
        # no-grad mode. See: https://github.com/pytorch/pytorch/issues/99968
        grad_fn_name = "Invalid"

    if grad_fn_name is None and grad_fn is not None:  # type: ignore[possibly-undefined]
        grad_fn_name = type(grad_fn).__name__
        if grad_fn_name == "CppFunction":
            grad_fn_name = grad_fn.name().rsplit("::", 1)[-1]

    if grad_fn_name is not None:
        suffixes.append(f"grad_fn=<{grad_fn_name}>")
    elif inp.requires_grad:
        suffixes.append("requires_grad=True")

    if self.has_names():
        suffixes.append(f"names={self.names}")

    if tangent is not None:
        suffixes.append(f"tangent={tangent}")

    string_repr = _add_suffixes(
        prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse  # type: ignore[possibly-undefined]
    )

    # Check if this instance is flagged as a parameter and change the repr accordingly.
    # Unfortunately, this function has to be aware of this detail.
    # NB: This is currently skipped for plain tensor parameters to maintain BC. In the future,
    # this should be done for those as well to produce a valid repr.
    if isinstance(self, torch.nn.Parameter) and not is_plain_tensor:
        string_repr = f"Parameter({string_repr})"

    return string_repr


def _functorch_wrapper_str_intern(tensor, *, tensor_contents=None):
    """Build the repr of a functorch-wrapped tensor around its unwrapped value.

    ``tensor_contents`` is accepted for signature parity with
    :func:`_str_intern` but is not used here.

    Raises:
        ValueError: if the wrapper is not a batched, grad-tracking or
            functional tensor.
    """
    level = torch._C._functorch.maybe_get_level(tensor)
    assert level != -1

    if torch._C._functorch.is_functionaltensor(tensor):
        # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure
        # that it's up to date first
        torch._sync(tensor)

    value = torch._C._functorch.get_unwrapped(tensor)
    value_repr = repr(value)

    indented_value_repr = textwrap.indent(value_repr, " " * 4)
    if torch._C._functorch.is_batchedtensor(tensor):
        bdim = torch._C._functorch.maybe_get_bdim(tensor)
        assert bdim != -1
        return (
            f"BatchedTensor(lvl={level}, bdim={bdim}, value=\n"
            f"{indented_value_repr}\n"
            f")"
        )
    if torch._C._functorch.is_gradtrackingtensor(tensor):
        return (
            f"GradTrackingTensor(lvl={level}, value=\n" f"{indented_value_repr}\n" f")"
        )
    if torch._C._functorch.is_functionaltensor(tensor):
        return f"FunctionalTensor(lvl={level}, value=\\\n{value_repr})"

    raise ValueError("We don't know how to print this, please file us an issue")


def _str(self, *, tensor_contents=None):
    """Return the repr of ``self``, computed with grad, dispatch modes and
    functorch disabled."""
    with torch.no_grad(), torch.utils._python_dispatch._disable_current_modes():
        # Not referenced again on purpose: the object only needs to stay alive
        # until the call below returns.
        guard = torch._C._DisableFuncTorch()
        return _str_intern(self, tensor_contents=tensor_contents)