Source code for sigpy.plot

# -*- coding: utf-8 -*-
"""This module contains plotting functions based on matplotlib
for image, line, and scatter plots.

A feature of these plotting functions is that
they can be controlled using only hotkeys
so the user does not need to move away from the keyboard.

Given an array ``x``, an example usage is:

    >>> ImagePlot(x)
    >>> LinePlot(x)
    >>> ScatterPlot(x)

"""
import datetime
import os
import subprocess
import uuid

import numpy as np

import sigpy as sp

__all__ = ["ImagePlot", "LinePlot", "ScatterPlot"]


# Help overlay text drawn by ImagePlot when 'h' is pressed. The
# ``$\bf{...}$`` spans are math-text markup used to render the key names in
# bold, which is why this is a raw string.
image_plot_help_str = r"""
$\bf{Hotkeys:}$
    $\bf{h:}$ show/hide hotkey menu.
    $\bf{x/y/z:}$ set current axis as x/y/z.
    $\bf{t:}$ swap between x and y.
    $\bf{c:}$ select current axis as color.
    $\bf{left/right:}$ change current axis.
    $\bf{up/down:}$ change slice along current axis.
    $\bf{a:}$ toggle hide all labels, titles and axes.
    $\bf{f:}$ toggle full screen.
    $\bf{m/p/r/i/l:}$  magnitude/phase/real/imaginary/log mode.
    $\bf{[/]:}$ change brightness.
    $\bf{\{/\}:}$ change contrast.
    $\bf{s:}$ save as png.
    $\bf{g/v:}$ save as gif/video along current axis.
    $\bf{q:}$ refresh and reset brightness/contrast.
    $\bf{0-9:}$ enter slice number.
    $\bf{enter:}$ set slice of current axis to entered number.
"""


class ImagePlot(object):
    """Plot array as image.

    Press 'h' for a menu for hotkeys.

    Args:
        im (array): image numpy/cupy array.
        x (int): x axis.
        y (int): y axis.
        z (None or int): z axis.
        c (None or int): color axis.
        hide_axes (bool): toggle hiding axes, labels and title.
        mode (str): specify magnitude, phase, real, imaginary,
            and log mode. {'m', 'p', 'r', 'i', 'l'}. When None, real
            data is shown in real mode and other data in magnitude mode.
        colormap (None or str): colormap passed to ``imshow``. ``"gray"``
            is used when None; a colorbar is drawn only when one is given.
        vmin (None or float): lower display limit. When None it is set
            to the minimum of the displayed image.
        vmax (None or float): upper display limit. When None it is set
            to the maximum of the displayed image.
        title (str): title.
        interpolation (str): plot interpolation.
        save_basename (str): saved png, gif, and video base name.
        fps (int): frame per seconds for gif and video.

    Raises:
        TypeError: if ``im`` has fewer than two dimensions.

    """

    def __init__(
        self,
        im,
        x=-1,
        y=-2,
        z=None,
        c=None,
        hide_axes=False,
        mode=None,
        colormap=None,
        vmin=None,
        vmax=None,
        title="",
        interpolation="nearest",
        save_basename="Figure",
        fps=10,
    ):
        if im.ndim < 2:
            raise TypeError(
                "Image dimension must at least be two, got {im_ndim}".format(
                    im_ndim=im.ndim
                )
            )

        # Imported here rather than at module level (presumably so that
        # importing the module does not require matplotlib).
        import matplotlib.pyplot as plt

        self.axim = None
        self.im = im
        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(111)
        self.shape = self.im.shape
        self.ndim = self.im.ndim
        # Start in the middle of every axis.
        self.slices = [s // 2 for s in self.shape]
        self.flips = [1] * self.ndim
        self.x = x % self.ndim
        self.y = y % self.ndim
        self.z = z % self.ndim if z is not None else None
        self.c = c % self.ndim if c is not None else None
        # Index of the axis currently controlled by the hotkeys.
        self.d = max(self.ndim - 3, 0)
        self.hide_axes = hide_axes
        self.show_help = False
        self.title = title
        self.interpolation = interpolation
        self.mode = mode
        self.colormap = colormap
        self.entering_slice = False
        self.vmin = vmin
        self.vmax = vmax
        self.save_basename = save_basename
        self.fps = fps
        self.help_text = None

        # Replace matplotlib's default key bindings with our own handler.
        self.fig.canvas.mpl_disconnect(
            self.fig.canvas.manager.key_press_handler_id
        )
        self.fig.canvas.mpl_connect("key_press_event", self.key_press)
        self.update_axes()
        self.update_image()
        self.fig.canvas.draw()
        plt.show()

    @staticmethod
    def _timestamped_filename(basename, extension, now=None):
        """Return ``basename`` followed by a time stamp and ``extension``.

        Args:
            basename (str): leading part of the file name.
            extension (str): file extension without the dot.
            now (None or datetime.datetime): time to stamp,
                defaults to the current local time.

        Returns:
            str: file name.

        """
        if now is None:
            now = datetime.datetime.now()
        return basename + now.strftime(" %Y-%m-%d at %I.%M.%S %p.") + extension

    @staticmethod
    def _remove_files(paths):
        """Delete every existing file in ``paths``, skipping missing ones."""
        for path in paths:
            if os.path.exists(path):
                os.remove(path)

    def _refresh(self, axes=True, image=True):
        """Update the requested parts of the figure and redraw the canvas."""
        if axes:
            self.update_axes()
        if image:
            self.update_image()
        self.fig.canvas.draw()

    def _write_frames(self, temp_basename, frames, **savefig_kwargs):
        """Save one png per slice along the current axis.

        File names are appended to ``frames`` before each file is written
        so that the caller can clean up even after a partial failure.
        """
        for i in range(self.shape[self.d]):
            self.slices[self.d] = i
            self._refresh()
            frame = "{} {:05d}.png".format(temp_basename, i)
            frames.append(frame)
            self.fig.savefig(
                frame, format="png", pad_inches=0, **savefig_kwargs
            )

    def _save_gif(self):
        """Save a gif that traverses the current axis, using ffmpeg."""
        filename = self._timestamped_filename(self.save_basename, "gif")
        temp_basename = uuid.uuid4()
        bbox = self.fig.get_tightbbox(self.fig.canvas.get_renderer())
        palette = "{} palette.png".format(temp_basename)
        frames = []
        try:
            self._write_frames(temp_basename, frames, bbox_inches=bbox)
            source = [
                "ffmpeg",
                "-f",
                "image2",
                "-s",
                "{}x{}".format(
                    int(bbox.width * self.fig.dpi),
                    int(bbox.height * self.fig.dpi),
                ),
                "-framerate",
                str(self.fps),
                "-i",
                "{} %05d.png".format(temp_basename),
            ]
            # First pass builds a color palette, second pass applies it.
            subprocess.run(source + ["-vf", "palettegen", palette])
            subprocess.run(
                source + ["-i", palette, "-lavfi", "paletteuse", filename]
            )
        finally:
            # Temporary files are removed even when ffmpeg is missing
            # or fails.
            self._remove_files([palette] + frames)

    def _save_video(self):
        """Save an mp4 that traverses the current axis, using ffmpeg."""
        filename = self._timestamped_filename(self.save_basename, "mp4")
        temp_basename = uuid.uuid4()
        frames = []
        try:
            self._write_frames(
                temp_basename, frames, transparent=True, bbox_inches="tight"
            )
            subprocess.run(
                [
                    "ffmpeg",
                    "-r",
                    str(self.fps),
                    "-i",
                    "{} %05d.png".format(temp_basename),
                    "-vf",
                    "crop=floor(iw/2)*2-10:floor(ih/2)*2-10",
                    "-pix_fmt",
                    "yuv420p",
                    "-crf",
                    "1",
                    "-vcodec",
                    "libx264",
                    "-preset",
                    "veryslow",
                    filename,
                ]
            )
        finally:
            self._remove_files(frames)

    def key_press(self, event):
        """Handle a matplotlib key press event.

        Args:
            event: matplotlib key event, only ``event.key`` is read.

        """
        key = event.key
        shown = [self.x, self.y, self.z, self.c]
        digit_keys = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]

        if key in ("up", "down"):
            if self.d not in shown:
                step = 1 if key == "up" else -1
                self.slices[self.d] = (
                    self.slices[self.d] + step
                ) % self.shape[self.d]
            else:
                # A displayed axis has no slice to change, flip it instead.
                self.flips[self.d] *= -1
            self._refresh()
        elif key in ("left", "right"):
            step = 1 if key == "right" else -1
            self.d = (self.d + step) % self.ndim
            self._refresh(image=False)
        elif key == "x" and self.d not in [self.x, self.z, self.c]:
            if self.d == self.y:
                self.x, self.y = self.y, self.x
            else:
                self.x = self.d
            self._refresh()
        elif key == "y" and self.d not in [self.y, self.z, self.c]:
            if self.d == self.x:
                self.x, self.y = self.y, self.x
            else:
                self.y = self.d
            self._refresh()
        elif key == "z" and self.d not in [self.x, self.y, self.c]:
            self.z = None if self.d == self.z else self.d
            self._refresh()
        elif (
            key == "c"
            and self.d not in [self.x, self.y, self.z]
            and self.shape[self.d] == 3
        ):
            self.c = None if self.d == self.c else self.d
            self._refresh()
        elif key == "t":
            self.x, self.y = self.y, self.x
            self._refresh()
        elif key == "a":
            self.hide_axes = not self.hide_axes
            self._refresh(image=False)
        elif key == "f":
            self.fig.canvas.manager.full_screen_toggle()
        elif key == "q":
            # Display limits are recomputed by update_image.
            self.vmin = None
            self.vmax = None
            self._refresh()
        elif key in ("]", "["):
            shift = (self.vmax - self.vmin) * 0.1
            if key == "]":
                shift = -shift
            self.vmin += shift
            self.vmax += shift
            self._refresh(axes=False)
        elif key in ("}", "{"):
            scale = 1.1 if key == "}" else 0.9
            width = self.vmax - self.vmin
            center = (self.vmax + self.vmin) / 2
            self.vmin = center - width * scale / 2
            self.vmax = center + width * scale / 2
            self._refresh(axes=False)
        elif key in ["m", "p", "r", "i", "l"]:
            self.vmin = None
            self.vmax = None
            self.mode = key
            self._refresh()
        elif key == "s":
            self.fig.savefig(
                self._timestamped_filename(self.save_basename, "png"),
                transparent=True,
                format="png",
                bbox_inches="tight",
                pad_inches=0,
            )
        elif key == "g":
            self._save_gif()
        elif key == "v":
            self._save_video()
        elif key in digit_keys + ["backspace"] and self.d not in shown:
            if self.entering_slice:
                if key == "backspace":
                    if self.entered_slice < 10:
                        # Last digit removed, leave entry mode.
                        self.entering_slice = False
                    else:
                        self.entered_slice //= 10
                else:
                    self.entered_slice = self.entered_slice * 10 + int(key)
            elif key != "backspace":
                self.entering_slice = True
                self.entered_slice = int(key)
            self._refresh(image=False)
        elif key == "enter" and self.entering_slice:
            self.entering_slice = False
            # Out of range entries are silently discarded.
            if self.entered_slice < self.shape[self.d]:
                self.slices[self.d] = self.entered_slice
            self._refresh()
        elif key == "h":
            self.show_help = not self.show_help
            self._refresh(axes=False)

    def update_image(self):
        """Extract the current view from the array and display it."""
        # Extract slice.
        idx = tuple(
            slice(None, None, self.flips[i])
            if i in [self.x, self.y, self.z, self.c]
            else self.slices[i]
            for i in range(self.ndim)
        )
        imv = sp.to_device(self.im[idx])

        # Transpose to have [z, y, x, c].
        imv_dims = [self.y, self.x]
        if self.z is not None:
            imv_dims = [self.z] + imv_dims

        if self.c is not None:
            imv_dims = imv_dims + [self.c]

        imv = np.transpose(imv, np.argsort(np.argsort(imv_dims)))
        imv = array_to_image(imv, color=self.c is not None)

        if self.mode is None:
            self.mode = "r" if np.isrealobj(imv) else "m"

        if self.mode == "m":
            imv = np.abs(imv)
        elif self.mode == "p":
            imv = np.angle(imv)
        elif self.mode == "r":
            imv = np.real(imv)
        elif self.mode == "i":
            imv = np.imag(imv)
        elif self.mode == "l":
            imv = np.abs(imv)
            # Zeros are shown as -31 instead of -inf.
            imv = np.log(imv, out=np.ones_like(imv) * -31, where=imv != 0)

        if self.vmin is None:
            self.vmin = imv.min()

        if self.vmax is None:
            self.vmax = imv.max()

        extent = [0, imv.shape[1], 0, imv.shape[0]]
        if self.axim is None:
            colormap = "gray" if self.colormap is None else self.colormap
            self.axim = self.ax.imshow(
                imv,
                vmin=self.vmin,
                vmax=self.vmax,
                cmap=colormap,
                origin="lower",
                interpolation=self.interpolation,
                aspect=1.0,
                extent=extent,
            )
            if self.colormap is not None:
                self.fig.colorbar(self.axim)
        else:
            self.axim.set_data(imv)
            self.axim.set_extent(extent)
            self.axim.set_clim(self.vmin, self.vmax)

        if self.help_text is None:
            bbox_props = dict(
                boxstyle="round", pad=1, fc="white", alpha=0.95, lw=0
            )
            # Anchored in axes coordinates so the menu stays centered for
            # non-square images and after the displayed shape changes.
            self.help_text = self.ax.text(
                0.5,
                0.5,
                image_plot_help_str,
                ha="center",
                va="center",
                linespacing=1.5,
                ma="left",
                size=8,
                bbox=bbox_props,
                transform=self.ax.transAxes,
            )

        self.help_text.set_visible(self.show_help)

    def _caption(self):
        """Return the axes caption describing the role of every axis.

        Displayed axes are written as x/y/z/c, prefixed with '-' when
        flipped; other axes show their slice index. The current axis is
        wrapped in brackets.
        """
        caption = "["
        for i in range(self.ndim):
            current = i == self.d
            caption += "[" if current else " "

            if self.flips[i] == -1 and (
                i == self.x or i == self.y or i == self.z or i == self.c
            ):
                caption += "-"

            if i == self.x:
                caption += "x"
            elif i == self.y:
                caption += "y"
            elif i == self.z:
                caption += "z"
            elif i == self.c:
                caption += "c"
            elif current and self.entering_slice:
                # Trailing underscore marks a slice number being typed.
                caption += str(self.entered_slice) + "_"
            else:
                caption += str(self.slices[i])

            caption += "]" if current else " "

        return caption + "]"

    def update_axes(self):
        """Update caption, title and axis visibility."""
        visible = not self.hide_axes
        if visible:
            self.ax.set_title(self._caption())
            self.fig.suptitle(self.title)
        else:
            self.ax.set_title("")
            self.fig.suptitle("")

        self.ax.xaxis.set_visible(visible)
        self.ax.yaxis.set_visible(visible)
        self.ax.title.set_visible(visible)
def mosaic_shape(batch):
    """Return a (rows, columns) grid with at least ``batch`` cells.

    Args:
        batch (int): number of images to arrange.

    Returns:
        tuple: grid shape.

    """
    rows = int(batch**0.5)
    grid = [rows, batch // rows]

    # Widen the grid until every image has a cell.
    while sp.prod(grid) < batch:
        grid[1] += 1

    # Trade one row for one column when that fits the batch exactly.
    if (grid[0] - 1) * (grid[1] + 1) == batch:
        grid[0] -= 1
        grid[1] += 1

    return tuple(grid)


def array_to_image(arr, color=False):
    """Tile the leading dimensions of an array into one mosaic image.

    Args:
        arr (array): array whose trailing two dimensions (three when
            ``color`` is set, the last one being the color channel)
            form one image.
        color (bool): treat the last dimension as color channels. The
            array is then divided by its maximum magnitude unless it
            is all zeros.

    Returns:
        array: ``arr`` itself when it already is a single image,
            otherwise the mosaic, padded with zeros where the grid has
            more cells than there are images.

    """
    all_zero = arr.max() == 0 and arr.min() == 0
    if color and not all_zero:
        arr = arr / np.abs(arr).max()

    # A single image needs no tiling.
    if arr.ndim == 2 or (color and arr.ndim == 3):
        return arr

    n_image_dims = 3 if color else 2
    img_shape = arr.shape[-n_image_dims:]
    batch = sp.prod(arr.shape[:-n_image_dims])
    mshape = mosaic_shape(batch)
    n_cells = sp.prod(mshape)

    stacked = arr.reshape((batch,) + img_shape)
    if n_cells == batch:
        img = stacked
    else:
        img = np.zeros((n_cells,) + img_shape, dtype=arr.dtype)
        img[:batch, ...] = stacked

    img = img.reshape(mshape + img_shape)
    height = img_shape[0] * mshape[0]
    width = img_shape[1] * mshape[1]
    if color:
        return np.transpose(img, (0, 2, 1, 3, 4)).reshape((height, width, 3))

    return np.transpose(img, (0, 2, 1, 3)).reshape((height, width))
class LinePlot(object):
    """Plot array as lines.

    Args:
        arr (array): array to plot.
        x (int): axis plotted along the horizontal direction.
        hide_axes (bool): toggle hiding axes, labels and title.
        mode (str): specify magnitude, phase, real, imaginary,
            and log mode. {'m', 'p', 'r', 'i', 'l'}.
        title (str): title.
        save_basename (str): saved png, gif, and video base name.
        fps (int): frame per seconds for gif and video.

    Hotkeys:
        x: select current dimension as x
        left/right: increment/decrement current dimension
        up/down: flip axis when current dimension is x
            otherwise increment/decrement slice at current dimension
        a: toggle hide all labels, titles and axes
        f: toggle full screen
        m: magnitude mode
        p: phase mode
        r: real mode
        i: imaginary mode
        l: log mode
        s: save as png.
        g: save as gif by traversing current dimension.
        v: save as video by traversing current dimension.

    """

    def __init__(
        self,
        arr,
        x=-1,
        hide_axes=False,
        mode="m",
        title="",
        save_basename="Figure",
        fps=10,
    ):
        # Imported here rather than at module level (presumably so that
        # importing the module does not require matplotlib).
        import matplotlib.pyplot as plt

        self.arr = arr
        self.axarr = None

        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(111)
        self.shape = self.arr.shape
        self.ndim = self.arr.ndim
        # Start in the middle of every axis.
        self.slices = [s // 2 for s in self.shape]
        self.flips = [1] * self.ndim
        self.x = x % self.ndim
        # Index of the axis currently controlled by the hotkeys.
        self.d = max(self.ndim - 3, 0)
        self.hide_axes = hide_axes
        self.title = title
        self.mode = mode
        self.save_basename = save_basename
        self.fps = fps
        self.bottom = None
        self.top = None

        # Replace matplotlib's default key bindings with our own handler.
        self.fig.canvas.mpl_disconnect(
            self.fig.canvas.manager.key_press_handler_id
        )
        self.fig.canvas.mpl_connect("key_press_event", self.key_press)
        self.update_axes()
        self.update_line()
        self.fig.canvas.draw()
        plt.show()

    @staticmethod
    def _timestamped_filename(basename, extension, now=None):
        """Return ``basename`` followed by a time stamp and ``extension``.

        Args:
            basename (str): leading part of the file name.
            extension (str): file extension without the dot.
            now (None or datetime.datetime): time to stamp,
                defaults to the current local time.

        Returns:
            str: file name.

        """
        if now is None:
            now = datetime.datetime.now()
        # %I is the 12-hour clock hour that pairs with %p.
        return basename + now.strftime(" %Y-%m-%d at %I.%M.%S %p.") + extension

    @staticmethod
    def _remove_files(paths):
        """Delete every existing file in ``paths``, skipping missing ones."""
        for path in paths:
            if os.path.exists(path):
                os.remove(path)

    def _refresh(self, axes=True, line=True):
        """Update the requested parts of the figure and redraw the canvas."""
        if axes:
            self.update_axes()
        if line:
            self.update_line()
        self.fig.canvas.draw()

    def _write_frames(self, temp_basename, frames, bbox):
        """Save one png per slice along the current axis.

        File names are appended to ``frames`` before each file is written
        so that the caller can clean up even after a partial failure.
        """
        for i in range(self.shape[self.d]):
            self.slices[self.d] = i
            self._refresh()
            frame = "{} {:05d}.png".format(temp_basename, i)
            frames.append(frame)
            self.fig.savefig(
                frame, format="png", bbox_inches=bbox, pad_inches=0
            )

    def _save_gif(self):
        """Save a gif that traverses the current axis, using ffmpeg."""
        filename = self._timestamped_filename(self.save_basename, "gif")
        temp_basename = uuid.uuid4()
        bbox = self.fig.get_tightbbox(self.fig.canvas.get_renderer())
        palette = "{} palette.png".format(temp_basename)
        frames = []
        try:
            self._write_frames(temp_basename, frames, bbox)
            source = [
                "ffmpeg",
                "-f",
                "image2",
                "-s",
                "{}x{}".format(
                    int(bbox.width * self.fig.dpi),
                    int(bbox.height * self.fig.dpi),
                ),
                "-framerate",
                str(self.fps),
                "-i",
                "{} %05d.png".format(temp_basename),
            ]
            # First pass builds a color palette, second pass applies it.
            subprocess.run(source + ["-vf", "palettegen", palette])
            subprocess.run(
                source + ["-i", palette, "-lavfi", "paletteuse", filename]
            )
        finally:
            # Temporary files are removed even when ffmpeg is missing
            # or fails.
            self._remove_files([palette] + frames)

    def _save_video(self):
        """Save a mov that traverses the current axis, using ffmpeg."""
        filename = self._timestamped_filename(self.save_basename, "mov")
        temp_basename = uuid.uuid4()
        bbox = self.fig.get_tightbbox(self.fig.canvas.get_renderer())
        frames = []
        try:
            self._write_frames(temp_basename, frames, bbox)
            subprocess.run(
                [
                    "ffmpeg",
                    "-framerate",
                    str(self.fps),
                    "-i",
                    "{} %05d.png".format(temp_basename),
                    "-vcodec",
                    "png",
                    filename,
                ]
            )
        finally:
            self._remove_files(frames)

    def key_press(self, event):
        """Handle a matplotlib key press event.

        Args:
            event: matplotlib key event, only ``event.key`` is read.

        """
        key = event.key

        if key in ("up", "down"):
            if self.d != self.x:
                step = 1 if key == "up" else -1
                self.slices[self.d] = (
                    self.slices[self.d] + step
                ) % self.shape[self.d]
            else:
                # The plotted axis has no slice to change, flip it instead.
                self.flips[self.d] *= -1
            self._refresh()
        elif key in ("left", "right"):
            step = 1 if key == "right" else -1
            self.d = (self.d + step) % self.ndim
            self._refresh(line=False)
        elif key == "x" and self.d != self.x:
            self.x = self.d
            self._refresh()
        elif key == "a":
            self.hide_axes = not self.hide_axes
            self._refresh(line=False)
        elif key == "f":
            self.fig.canvas.manager.full_screen_toggle()
        elif key in ("m", "p", "r", "i", "l"):
            self.mode = key
            # Vertical limits are recomputed by update_line.
            self.bottom = None
            self.top = None
            self._refresh()
        elif key == "s":
            self.fig.savefig(
                self._timestamped_filename(self.save_basename, "png"),
                transparent=True,
                format="png",
                bbox_inches="tight",
                pad_inches=0,
            )
        elif key == "g":
            self._save_gif()
        elif key == "v":
            self._save_video()

    def update_line(self):
        """Extract the current line from the array and plot it."""
        # Move the plotted axis last, then index all the others.
        order = [i for i in range(self.ndim) if i != self.x] + [self.x]
        idx = tuple(
            [self.slices[i] for i in order[:-1]]
            + [slice(None, None, self.flips[self.x])]
        )

        arrv = self.arr.transpose(order)[idx]

        if self.mode == "m":
            arrv = np.abs(arrv)
        elif self.mode == "p":
            arrv = np.angle(arrv)
        elif self.mode == "r":
            arrv = np.real(arrv)
        elif self.mode == "i":
            arrv = np.imag(arrv)
        elif self.mode == "l":
            # Small offset avoids taking the log of zero.
            eps = 1e-31
            arrv = np.log(np.abs(arrv) + eps)

        if self.bottom is None:
            self.bottom = arrv.min()

        if self.top is None:
            self.top = arrv.max()

        if self.axarr is None:
            self.axarr = self.ax.plot(arrv)[0]
        else:
            self.axarr.set_xdata(np.arange(len(arrv)))
            self.axarr.set_ydata(arrv)
            # The line length changes when another axis becomes x, so the
            # horizontal limits have to be recomputed from the new data.
            self.ax.relim()
            self.ax.autoscale_view(scalex=True, scaley=False)

        self.ax.set_ylim(self.bottom, self.top)

    def _caption(self):
        """Return the axes caption describing the role of every axis.

        The plotted axis is written as x, prefixed with '-' when flipped;
        other axes show their slice index. The current axis is wrapped
        in brackets.
        """
        caption = "Slice: ["
        for i in range(self.ndim):
            current = i == self.d
            caption += "[" if current else " "

            if self.flips[i] == -1 and i == self.x:
                caption += "-"

            if i == self.x:
                caption += "x"
            else:
                caption += str(self.slices[i])

            caption += "]" if current else " "

        return caption + "]"

    def update_axes(self):
        """Update caption, title and axis visibility."""
        visible = not self.hide_axes
        if visible:
            self.ax.set_title(self._caption())
            self.ax.axis("on")
            self.fig.suptitle(self.title)
        else:
            self.ax.set_title("")
            self.fig.suptitle("")

        self.ax.xaxis.set_visible(visible)
        self.ax.yaxis.set_visible(visible)
        self.ax.title.set_visible(visible)
class ScatterPlot(object):
    """Plot array as scatter.

    Args:
        coord (array): point coordinates, the last dimension must
            have length 2.
        data (None or array): values used to color the points. Its
            trailing dimensions must match ``coord.shape[:-1]``; any
            leading dimensions can be browsed with the hotkeys.
            Defaults to ones.
        z (None or int): z dimension.
            NOTE(review): the hotkey that toggles it is disabled,
            confirm whether a non-None value is still supported.
        hide_axes (bool): toggle hiding axes, labels and title.
        mode (str): specify magnitude, phase, real, imaginary,
            and log mode. {'m', 'p', 'r', 'i', 'l'}.
        title (str): title.
        save_basename (str): saved png, gif, and video base name.
        fps (int): frame per seconds for gif and video.

    Hotkeys:
        left/right: increment/decrement current dimension
        up/down: flip axis when current dimension is z
            otherwise increment/decrement slice at current dimension
        a: toggle hide all labels, titles and axes
        f: toggle full screen
        m: magnitude mode
        p: phase mode
        r: real mode
        i: imaginary mode
        l: log mode
        s: save as png.
        g: save as gif by traversing current dimension.
        v: save as video by traversing current dimension.
        0-9: enter slice number.
        enter: set slice of current dimension to entered number.

    """

    def __init__(
        self,
        coord,
        data=None,
        z=None,
        hide_axes=False,
        mode="m",
        title="",
        save_basename="Figure",
        fps=10,
    ):
        # Imported here rather than at module level (presumably so that
        # importing the module does not require matplotlib).
        import matplotlib.pyplot as plt

        self.coord = coord
        assert coord.shape[-1] == 2
        if data is None:
            self.data = np.ones(coord.shape[:-1])
        else:
            self.data = data

        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(111)
        self.ax.set_facecolor("k")
        self.ax.axis("equal")

        for c, d in zip(coord.shape[:-1], self.data.shape[-coord.ndim + 1 :]):
            assert c == d

        # Only the leading (batch) dimensions of data can be browsed.
        self.ndim = self.data.ndim - self.coord.ndim + 1
        self.shape = self.data.shape[: self.ndim]
        # Start in the middle of every axis.
        self.slices = [s // 2 for s in self.shape]
        self.flips = [1] * self.ndim
        self.z = z % self.ndim if z is not None else None
        # Index of the axis currently controlled by the hotkeys.
        self.d = 0
        self.hide_axes = hide_axes
        self.title = title
        self.mode = mode
        self.axsc = None
        self.entering_slice = False
        self.save_basename = save_basename
        self.fps = fps
        self.vmin = None
        self.vmax = None

        # Replace matplotlib's default key bindings with our own handler.
        self.fig.canvas.mpl_disconnect(
            self.fig.canvas.manager.key_press_handler_id
        )
        self.fig.canvas.mpl_connect("key_press_event", self.key_press)
        self.update_axes()
        self.update_data()
        self.fig.canvas.draw()
        plt.show()

    @staticmethod
    def _timestamped_filename(basename, extension, now=None):
        """Return ``basename`` followed by a time stamp and ``extension``.

        Args:
            basename (str): leading part of the file name.
            extension (str): file extension without the dot.
            now (None or datetime.datetime): time to stamp,
                defaults to the current local time.

        Returns:
            str: file name.

        """
        if now is None:
            now = datetime.datetime.now()
        # %I is the 12-hour clock hour that pairs with %p.
        return basename + now.strftime(" %Y-%m-%d at %I.%M.%S %p.") + extension

    @staticmethod
    def _remove_files(paths):
        """Delete every existing file in ``paths``, skipping missing ones."""
        for path in paths:
            if os.path.exists(path):
                os.remove(path)

    def _refresh(self, axes=True, data=True):
        """Update the requested parts of the figure and redraw the canvas."""
        if axes:
            self.update_axes()
        if data:
            self.update_data()
        self.fig.canvas.draw()

    def _write_frames(self, temp_basename, frames, bbox):
        """Save one png per slice along the current axis.

        File names are appended to ``frames`` before each file is written
        so that the caller can clean up even after a partial failure.
        """
        for i in range(self.shape[self.d]):
            self.slices[self.d] = i
            self._refresh()
            frame = "{} {:05d}.png".format(temp_basename, i)
            frames.append(frame)
            self.fig.savefig(
                frame, format="png", bbox_inches=bbox, pad_inches=0
            )

    def _save_gif(self):
        """Save a gif that traverses the current axis, using ffmpeg."""
        filename = self._timestamped_filename(self.save_basename, "gif")
        temp_basename = uuid.uuid4()
        bbox = self.fig.get_tightbbox(self.fig.canvas.get_renderer())
        palette = "{} palette.png".format(temp_basename)
        frames = []
        try:
            self._write_frames(temp_basename, frames, bbox)
            source = [
                "ffmpeg",
                "-f",
                "image2",
                "-s",
                "{}x{}".format(
                    int(bbox.width * self.fig.dpi),
                    int(bbox.height * self.fig.dpi),
                ),
                "-framerate",
                str(self.fps),
                "-i",
                "{} %05d.png".format(temp_basename),
            ]
            # First pass builds a color palette, second pass applies it.
            subprocess.run(source + ["-vf", "palettegen", palette])
            subprocess.run(
                source + ["-i", palette, "-lavfi", "paletteuse", filename]
            )
        finally:
            # Temporary files are removed even when ffmpeg is missing
            # or fails.
            self._remove_files([palette] + frames)

    def _save_video(self):
        """Save a mov that traverses the current axis, using ffmpeg."""
        filename = self._timestamped_filename(self.save_basename, "mov")
        temp_basename = uuid.uuid4()
        bbox = self.fig.get_tightbbox(self.fig.canvas.get_renderer())
        frames = []
        try:
            self._write_frames(temp_basename, frames, bbox)
            subprocess.run(
                [
                    "ffmpeg",
                    "-framerate",
                    str(self.fps),
                    "-i",
                    "{} %05d.png".format(temp_basename),
                    "-vcodec",
                    "png",
                    filename,
                ]
            )
        finally:
            self._remove_files(frames)

    def _update_entered_slice(self, key):
        """Apply a digit or backspace key to the slice number being typed.

        Args:
            key (str): one of '0'-'9' or 'backspace'.

        """
        if self.entering_slice:
            if key == "backspace":
                if self.entered_slice < 10:
                    # Last digit removed, leave entry mode.
                    self.entering_slice = False
                else:
                    self.entered_slice //= 10
            else:
                self.entered_slice = self.entered_slice * 10 + int(key)
        elif key != "backspace":
            # Backspace with nothing typed is ignored; only a digit
            # starts an entry.
            self.entering_slice = True
            self.entered_slice = int(key)

    def key_press(self, event):
        """Handle a matplotlib key press event.

        Args:
            event: matplotlib key event, only ``event.key`` is read.

        """
        key = event.key
        entry_keys = [
            "0",
            "1",
            "2",
            "3",
            "4",
            "5",
            "6",
            "7",
            "8",
            "9",
            "backspace",
        ]
        axis_keys = ["up", "down", "left", "right", "g", "v", "enter"]

        # Without batch dimensions (e.g. the default data) there is no
        # axis to navigate, so keys that index into one are ignored.
        if self.ndim == 0 and key in axis_keys + entry_keys:
            return

        if key in ("up", "down"):
            if self.d != self.z:
                step = 1 if key == "up" else -1
                self.slices[self.d] = (
                    self.slices[self.d] + step
                ) % self.shape[self.d]
            else:
                self.flips[self.d] *= -1
            self._refresh()
        elif key in ("left", "right"):
            step = 1 if key == "right" else -1
            self.d = (self.d + step) % self.ndim
            self._refresh(data=False)
        elif key == "a":
            self.hide_axes = not self.hide_axes
            self._refresh(data=False)
        elif key == "f":
            self.fig.canvas.manager.full_screen_toggle()
        elif key in ("m", "p", "r", "i", "l"):
            self.mode = key
            # Color limits are recomputed by update_data.
            self.vmin = None
            self.vmax = None
            self._refresh()
        elif key == "s":
            self.fig.savefig(
                self._timestamped_filename(self.save_basename, "png"),
                transparent=True,
                format="png",
                bbox_inches="tight",
                pad_inches=0,
            )
        elif key == "g":
            self._save_gif()
        elif key == "v":
            self._save_video()
        elif key in entry_keys and self.d != self.z:
            self._update_entered_slice(key)
            self._refresh(data=False)
        elif key == "enter" and self.entering_slice:
            self.entering_slice = False
            # Out of range entries are silently discarded.
            if self.entered_slice < self.shape[self.d]:
                self.slices[self.d] = self.entered_slice
            self._refresh()

    def update_data(self):
        """Extract the current values from the data and plot them."""
        idx = tuple(
            slice(None, None, self.flips[i]) if i == self.z else self.slices[i]
            for i in range(self.ndim)
        )

        if idx:
            datav = sp.to_device(self.data[idx])
        else:
            datav = sp.to_device(self.data)

        coordv = sp.to_device(self.coord)

        if self.mode == "m":
            datav = np.abs(datav)
        elif self.mode == "p":
            datav = np.angle(datav)
        elif self.mode == "r":
            datav = np.real(datav)
        elif self.mode == "i":
            datav = np.imag(datav)
        elif self.mode == "l":
            # Small offset avoids taking the log of zero.
            eps = 1e-31
            datav = np.log(np.abs(datav) + eps)

        datav = datav.ravel()

        if self.vmin is None:
            # Constant data would give an empty color range.
            if datav.min() == datav.max():
                self.vmin = 0
            else:
                self.vmin = datav.min()

        if self.vmax is None:
            self.vmax = datav.max()

        if self.axsc is None:
            self.axsc = self.ax.scatter(
                coordv[..., 0].ravel(),
                coordv[..., 1].ravel(),
                c=datav,
                s=1,
                linewidths=0,
                cmap="gray",
                vmin=self.vmin,
                vmax=self.vmax,
            )
        else:
            # Keep (x, y) pairs together, in the same order as the points
            # passed to scatter above.
            self.axsc.set_offsets(coordv.reshape([-1, 2]))
            # The values are scalars to be colormapped, not colors.
            self.axsc.set_array(datav)
            self.axsc.set_clim(self.vmin, self.vmax)

    def _caption(self):
        """Return the axes caption describing the role of every axis.

        The z axis is written as z, prefixed with '-' when flipped; other
        axes show their slice index. The current axis is wrapped in
        brackets.
        """
        caption = "["
        for i in range(self.ndim):
            current = i == self.d
            caption += "[" if current else " "

            if self.flips[i] == -1 and i == self.z:
                caption += "-"

            if i == self.z:
                caption += "z"
            elif current and self.entering_slice:
                # Trailing underscore marks a slice number being typed.
                caption += str(self.entered_slice) + "_"
            else:
                caption += str(self.slices[i])

            caption += "]" if current else " "

        return caption + "]"

    def update_axes(self):
        """Update caption, title and axis visibility."""
        visible = not self.hide_axes
        if visible:
            self.ax.set_title(self._caption())
            self.fig.suptitle(self.title)
        else:
            self.ax.set_title("")
            self.fig.suptitle("")

        self.ax.xaxis.set_visible(visible)
        self.ax.yaxis.set_visible(visible)
        self.ax.title.set_visible(visible)