import abc
from pathlib import Path
import matplotlib as mpl
from matplotlib import pyplot, colors, cm
from mpl_toolkits.mplot3d import art3d
import numpy as np
from hikari.utility import gnuplot_map_palette, make_abspath, \
mpl_map_palette, sph2cart
from hikari.resources import gnuplot_angular_heatmap_template
[docs]
class ArtistError(Exception):
    """Exception raised when a problem with plotting in hikari occurs."""

    def __init__(self, message: str) -> None:
        """
        Create the exception with a single, mandatory description.

        :param message: text forwarded unchanged to :class:`Exception`
        """
        super().__init__(message)
[docs]
class ArtistFactory:
    """A factory class: maps registered names onto artist constructors."""

    def __init__(self):
        # name -> callable returning an artist
        self._artists = {}

    def register(self, name, artist):
        """
        Make `artist` available under `name`, replacing any previous entry.

        :param name: key later passed to :meth:`create`
        :param artist: callable (typically a class) used to build the artist
        """
        self._artists[name] = artist

    def create(self, name, **kwargs):
        """
        Build an artist previously registered under `name`.

        :param name: key used during registration
        :param kwargs: keyword arguments forwarded to the registered callable
        :return: whatever the registered callable returns
        :raises ValueError: if nothing truthy is registered under `name`
        """
        constructor = self._artists.get(name)
        if constructor:
            return constructor(**kwargs)
        raise ValueError(f'Artist called "{name}" has not been registered!')
[docs]
class Artist:
    """Base class used for plotting matplotlib and gnuplot plots"""

    @staticmethod
    def _assert_is_iterable(iterable, length=0):
        """
        Check that `iterable` supports `len()` and, optionally, has `length`.

        :param iterable: object to be validated
        :param length: required number of elements; the length is only
            compared when this value is greater than zero
        :raises ArtistError: if `len(iterable)` is undefined, or if it
            differs from a positive `length`
        """
        try:
            actual_length = len(iterable)
        except TypeError:
            actual_length = None
        wrong_length = length > 0 and actual_length != length
        if actual_length is None or wrong_length:
            # Build the suffix separately: applying `if length else ''` to the
            # whole concatenated string would blank the message for length=0.
            suffix = f' of length {length}' if length else ''
            raise ArtistError(
                f'object {iterable} should be an iterable{suffix}')

    @abc.abstractmethod
    def plot(self, path):
        """
        Produce the plot associated with `path`; defined by subclasses.

        :param path: location of the plot, interpreted by the subclass
        """
        pass
[docs]
class AngularHeatmapArtist(Artist, abc.ABC):
    """Base class used for plotting angular heatmaps"""

    HEAT_EXTENSION = '.lst'
    HISTOGRAM_EXTENSION = '.his'

    def __init__(self):
        # every validated setting starts as an empty tuple, i.e. "not set"
        self._x_axis = ()
        self._y_axis = ()
        self._z_axis = ()
        self._focus = ()
        self._heat_limits = ()
        self._polar_limits = ()
        self._azimuth_limits = ()
        self.histogram = False

    def _checked_tuple(self, values, length):
        """Validate `values` to be of `length` and return them as a tuple."""
        self._assert_is_iterable(values, length)
        return tuple(values)

    @property
    def x_axis(self):
        """Three-element tuple describing the x axis."""
        return self._x_axis

    @x_axis.setter
    def x_axis(self, a):
        self._x_axis = self._checked_tuple(a, 3)

    @property
    def y_axis(self):
        """Three-element tuple describing the y axis."""
        return self._y_axis

    @y_axis.setter
    def y_axis(self, a):
        self._y_axis = self._checked_tuple(a, 3)

    @property
    def z_axis(self):
        """Three-element tuple describing the z axis."""
        return self._z_axis

    @z_axis.setter
    def z_axis(self, a):
        self._z_axis = self._checked_tuple(a, 3)

    @property
    def heat_limits(self):
        """Two-element tuple with the limits of the heat scale."""
        return self._heat_limits

    @heat_limits.setter
    def heat_limits(self, lims):
        self._heat_limits = self._checked_tuple(lims, 2)

    @property
    def focus(self):
        """Tuple of focus points, each validated to have three elements."""
        return self._focus

    @focus.setter
    def focus(self, coords):
        # the container itself may have any length, but must be sized...
        self._assert_is_iterable(coords)
        # ...while every point in it must have exactly three elements
        for point in coords:
            self._assert_is_iterable(point, 3)
        self._focus = tuple(coords)

    @property
    def polar_limits(self):
        """Two-element tuple with the limits of the polar angle."""
        return self._polar_limits

    @polar_limits.setter
    def polar_limits(self, lims):
        self._polar_limits = self._checked_tuple(lims, 2)

    @property
    def azimuth_limits(self):
        """Two-element tuple with the limits of the azimuth angle."""
        return self._azimuth_limits

    @azimuth_limits.setter
    def azimuth_limits(self, lims):
        self._azimuth_limits = self._checked_tuple(lims, 2)
[docs]
class GnuplotArtist(abc.ABC):
    """Base class used for plotting gnuplot plots"""

    GNUPLOT_EXTENSION = '.gnu'

    def __init__(self):
        super().__init__()
        # the empty-string key selects the initial palette
        self._heat_palette = self._palette_named('')

    @staticmethod
    def _palette_named(name):
        """Return the `gnuplot_map_palette` entry stored under `name`."""
        return gnuplot_map_palette[name]

    @property
    def heat_palette(self):
        """
        Palette taken from `gnuplot_map_palette`.

        Reading returns the stored palette entry; assigning expects a *name*
        and stores the entry found under it (unknown names raise `KeyError`).
        """
        return self._heat_palette

    @heat_palette.setter
    def heat_palette(self, name):
        self._heat_palette = self._palette_named(name)
[docs]
class MatplotlibArtist(abc.ABC):
    """Base class used for plotting matplotlib plots"""

    def __init__(self):
        super().__init__()
        # the empty-string key selects the initial palette
        self._heat_palette = self._colormap_named('')

    @staticmethod
    def _colormap_named(name):
        """Build a 256-level colormap from `mpl_map_palette[name]`."""
        palette = mpl_map_palette[name]
        return colors.LinearSegmentedColormap.from_list(
            'heatmap', palette, N=256)

    @property
    def heat_palette(self):
        """
        Colormap built from an entry of `mpl_map_palette`.

        Reading returns the stored colormap; assigning expects a palette
        *name* and stores a new `LinearSegmentedColormap` made from the entry
        found under it (unknown names raise `KeyError`).
        """
        return self._heat_palette

    @heat_palette.setter
    def heat_palette(self, name):
        self._heat_palette = self._colormap_named(name)
[docs]
class GnuplotAngularHeatmapArtist(GnuplotArtist, AngularHeatmapArtist):
    """Base class used for plotting gnuplot angular heatmaps"""

    # format-string used by `plot`; subclasses may override it
    template = gnuplot_angular_heatmap_template

    @property
    def focus_string(self):
        """Gnuplot `set label` commands, one line per point in `focus`."""
        label = "set label at {}, {}, {} '' point ls 10 front"
        return '\n'.join(label.format(*f) for f in self.focus)

    def plot(self, path):
        """
        Write a gnuplot script next to `path` and run gnuplot on it.

        The script is named after the stem of `path` with
        `GNUPLOT_EXTENSION` appended, is placed in the same directory, and
        is produced by filling `template` with the axes, limits, palette,
        histogram flag, focus labels and the stem (as `job_name`).
        Gnuplot is then started in that directory with the script name as
        its only argument. Its exit status is not inspected.

        :param path: path of the target plot; made absolute with
            `make_abspath`
        :raises ArtistError: if the gnuplot process cannot be started
        """
        # local import, in line with this method's previous local `os` import
        import subprocess

        png_path = Path(make_abspath(path))
        directory, stem = png_path.parent, png_path.stem
        gnu_name = stem + self.GNUPLOT_EXTENSION
        gnu_path = directory / gnu_name
        x_axis, y_axis, z_axis = self.x_axis, self.y_axis, self.z_axis
        # use the class attribute so that an overridden `template` is honoured
        script = self.template.format(
            axis_x1=x_axis[0], axis_x2=x_axis[1], axis_x3=x_axis[2],
            axis_y1=y_axis[0], axis_y2=y_axis[1], axis_y3=y_axis[2],
            axis_z1=z_axis[0], axis_z2=z_axis[1], axis_z3=z_axis[2],
            cplt_min=self.heat_limits[0],
            cplt_max=self.heat_limits[1],
            histogram=int(self.histogram),
            focus_string=self.focus_string,
            job_name=stem,
            min_ph=self.azimuth_limits[0],
            max_ph=self.azimuth_limits[1],
            min_th=self.polar_limits[0],
            max_th=self.polar_limits[1],
            palette=self.heat_palette)
        with open(gnu_path, 'w') as f:
            f.write(script)
        try:
            # An argument list plus `cwd` replaces the former shell string
            # ('cd <dir>; gnuplot <file>'), which broke on paths containing
            # spaces or shell metacharacters and on shells without ';'.
            # A non-zero exit status is deliberately not treated as an error,
            # exactly as before.
            subprocess.run(['gnuplot', gnu_name], cwd=str(directory))
        except OSError as err:
            raise ArtistError(
                f'OSError passed: Cannot plot {gnu_name}') from err
[docs]
class MatplotlibAngularHeatmapArtist(MatplotlibArtist, AngularHeatmapArtist):
    """Base class used for plotting matplotlib angular heatmaps"""

    MESH_EXTENSION = '.dat'

    @staticmethod
    def _mpl_major_minor():
        """
        Return matplotlib's version as a `(major, minor)` tuple of ints.

        Only the leading "<digits>.<digits>" of `mpl.__version__` is read, so
        suffixes such as "rc1" or ".dev0+g123" cannot break the conversion.
        `(0, 0)` is returned if the version string has an unexpected form.
        """
        import re
        match = re.match(r'(\d+)\.(\d+)', mpl.__version__)
        if match is None:
            return 0, 0
        return int(match.group(1)), int(match.group(2))

    def plot(self, path):
        """
        Draw the heatmap on a unit sphere and save it as a png at `path`.

        The heat values are read with `numpy.loadtxt` from a file located in
        the directory of `path`, named after its stem with `MESH_EXTENSION`
        appended. The figure is closed once saved (or if drawing fails), so
        repeated calls do not accumulate open figures.

        :param path: path of the png to write; made absolute with
            `make_abspath`
        """
        # OS and I/O operations
        png_path = Path(make_abspath(path))
        mesh_name = png_path.stem + self.MESH_EXTENSION
        mesh_path = png_path.parent / mesh_name
        heat_mesh = np.loadtxt(str(mesh_path))

        fig = pyplot.figure(figsize=(8, 6))
        try:
            self._draw_heatmap(fig, heat_mesh)
            pyplot.subplots_adjust(left=0.0, bottom=0.0, right=0.95, top=1.0)
            pyplot.savefig(png_path, dpi=100, format='png', bbox_inches=None)
        finally:
            # release the figure; pyplot keeps it alive otherwise
            pyplot.close(fig)

    def _draw_heatmap(self, fig, heat_mesh):
        """Populate `fig` with the sphere, axes, focus points and colorbar."""
        # set-up the plot
        ax = fig.add_subplot(111, projection='3d')
        # NOTE: this changes the global rcParams and is not reverted
        pyplot.rcParams.update({'font.size': 16})
        ax.view_init(elev=90 - sum(self.polar_limits) / 2,
                     azim=sum(self.azimuth_limits) / 2)
        if self._mpl_major_minor() >= (3, 6):  # introduced in mpl 3.3.*
            ax.set_box_aspect(None, zoom=1.4)
        else:  # deprecated mpl 3.6.*, privatized 3.8.*; similar, not identical
            ax.dist = 6.5
        # white single points at opposite corners; presumably there to make
        # the axes limits span [-1, 1] in every direction
        ax.plot([1], [1], [1], 'w')
        ax.plot([-1], [-1], [-1], 'w')
        ax.set_axis_off()

        # prepare surface in cartesian coordinates
        polar_range = np.linspace(start=self.polar_limits[0],
                                  stop=self.polar_limits[1],
                                  num=heat_mesh.shape[1])  # 1=angle_res
        azimuth_range = np.linspace(start=self.azimuth_limits[0],
                                    stop=self.azimuth_limits[1],
                                    num=heat_mesh.shape[0])
        polar_mesh, azimuth_mesh = np.meshgrid(polar_range, azimuth_range)
        x_mesh, y_mesh, z_mesh = sph2cart(r=np.ones_like(polar_mesh),
                                          p=np.deg2rad(polar_mesh),
                                          a=np.deg2rad(azimuth_mesh))
        try:
            np.warnings.filterwarnings(  # mpl uses deprecated numpy
                'ignore', category=np.VisibleDeprecationWarning)
        except AttributeError:
            # this numpy lacks `np.warnings` or the warning class; skip
            pass
        ax.plot_wireframe(x_mesh, y_mesh, z_mesh, colors='k', linewidth=0.25)

        # color map declarations
        m = cm.ScalarMappable(cmap=self.heat_palette)
        m.set_array(heat_mesh)
        m.set_clim(*self.heat_limits)
        pyplot.colorbar(m, ax=ax, fraction=0.06, pad=0.0, shrink=0.9)
        norm = colors.Normalize(*self.heat_limits)

        # draw (100), (010), (001) axes: a segment from each axis vector to
        # the same vector scaled by `len_`, labelled at the far end
        len_ = 1.25
        labelled_axes = ((self.x_axis, 'r', '(100)'),
                         (self.y_axis, 'g', '(010)'),
                         (self.z_axis, 'b', '(001)'))
        for axis, color, label in labelled_axes:
            tip = [component * len_ for component in axis]
            ax.add_line(art3d.Line3D((axis[0], tip[0]),
                                     (axis[1], tip[1]),
                                     (axis[2], tip[2]),
                                     color=color, linewidth=5, zorder=9))
            ax.text(*tip, label, zorder=10, color='k')

        # draw focus points
        if self.focus:
            xf, yf, zf = np.array(self.focus).T
            ax.plot(xf, yf, zf, linestyle='none', marker='D', markersize=10,
                    markerfacecolor='none', markeredgewidth=2,
                    markeredgecolor='k', zorder=8)

        # prepare smaller heat mesh for polygon centers and plot the heatmap:
        # each face takes the mean of the heat at its four corners
        face_heat_mesh = (heat_mesh[1:, 1:] + heat_mesh[1:, :-1] +
                          heat_mesh[:-1, 1:] + heat_mesh[:-1, :-1]) / 4
        color_mesh = self.heat_palette(norm(face_heat_mesh))
        for item in [fig, ax]:
            item.patch.set_visible(False)
        ax.plot_surface(x_mesh, y_mesh, z_mesh, rstride=1, cstride=1,
                        cmap=self.heat_palette, linewidth=0,
                        antialiased=False, facecolors=color_mesh)
# Module-level factory through which the concrete artists defined above
# can be created by name.
artist_factory = ArtistFactory()
artist_factory.register('gnuplot_angular_heatmap_artist',
                        GnuplotAngularHeatmapArtist)
artist_factory.register('matplotlib_angular_heatmap_artist',
                        MatplotlibAngularHeatmapArtist)