Initial project version

This commit is contained in:
sebastiankloth
2022-04-11 11:01:13 +02:00
commit f40f2badd8
18 changed files with 2012 additions and 0 deletions

77
tudplot/__init__.py Executable file
View File

@ -0,0 +1,77 @@
import os
import numpy
import matplotlib as mpl
from matplotlib import pyplot
from cycler import cycler
from .xmgrace import export_to_agr, load_agr_data
from .tud import tudcolors, nominal_colors, sequential_colors
from .utils import facet_plot, CurvedText as curved_text
def activate(scheme='b', full=False, sequential=False, cmap='blue-red', **kwargs):
"""
Activate the tud design.
Args:
scheme (opt.): Color scheme to activate, default is 'b'.
full (opt.):
Activate the full color palette. If False a smaller color palette is used.
If a number N is given, N colors will be chosen based on a interpolation of
all tudcolors.
sequential (opt.): Activate a number of sequential colors from a color map.
cmap (opt.):
Colormap to use for sequential colors, can be either from `~tudplot.tud.cmaps`
or any matplotlib color map. Range of the color map values can be given as
cmap_min and cmap_max, respectively.
**kwargs: Any matplotlib rc paramter may be given as keyword argument.
"""
mpl.pyplot.style.use(os.path.join(os.path.dirname(__file__), 'tud.mplstyle'))
if full:
if isinstance(full, int):
cmap = mpl.colors.LinearSegmentedColormap.from_list('tud{}'.format(scheme),
tudcolors[scheme])
colors = [cmap(x) for x in numpy.linspace(0, 1, full)]
else:
colors = tudcolors[scheme]
elif sequential:
colors = sequential_colors(sequential, cmap=cmap, min=kwargs.pop('cmap_min', 0),
max=kwargs.pop('cmap_max', 1))
else:
colors = nominal_colors[scheme]
mpl.rcParams['axes.prop_cycle'] = cycler('color', colors)
def saveagr(filename, figure=None, convert_latex=True):
"""
Save the current figure in xmgrace format.
Args:
filename: Agrfile to save the figure to
figure (opt.):
Figure that will be saved, if not given the current figure is saved
"""
figure = figure or pyplot.gcf()
figure.canvas.draw()
export_to_agr(figure, filename, convert_latex=convert_latex)
def markfigure(x, y, s, ax=None, **kwargs):
if ax is None:
ax = pyplot.gca()
kwargs['transform'] = ax.transAxes
kwargs['ha'] = 'center'
kwargs['va'] = 'center'
# kwargs.setdefault('fontsize', 'large')
ax.text(x, y, s, **kwargs)
def marka(x, y):
markfigure(x, y, '(a)')
def markb(x, y):
markfigure(x, y, '(b)')

157
tudplot/altair.py Executable file
View File

@ -0,0 +1,157 @@
import altair
from altair import Config, Color, Shape, Column, Row, Encoding, Scale, Axis
from random import randint
import os
import logging
import matplotlib.pyplot as plt
from .tud import nominal_colors, full_colors
def filter_nulltime_json(data):
if 'time' in data:
data = data[data.time > 0]
return altair.pipe(data, altair.to_json)
altair.renderers.enable('notebook')
altair.data_transformers.register('json_logtime', filter_nulltime_json)
altair.data_transformers.enable('json_logtime')
def my_theme(*args, **kwargs):
return {
'range': {
'ordinal': altair.VgScheme('viridis'),
'ramp': {'scheme': 'viridis'}
}
}
altair.themes.register('my-theme', my_theme)
altair.themes.enable('my-theme')
class BaseMixin(Encoding):
def __init__(self, *args, **kwargs):
kwargs.setdefault('scale', altair.Scale(zero=False))
super().__init__(*args, **kwargs)
class LogMixin(BaseMixin):
def __init__(self, *args, **kwargs):
kwargs['scale'] = altair.Scale(type='log')
super().__init__(*args, **kwargs)
class X(altair.X, BaseMixin):
pass
class Y(altair.Y, BaseMixin):
pass
class LogX(altair.X, LogMixin):
pass
class LogY(altair.Y, LogMixin):
pass
class DataHandler(altair.Data):
def __init__(self, df):
self._filename = '.altair.json'
with open(self._filename, 'w') as f:
f.write(df.to_json())
super().__init__(url=self._filename)
class Chart(altair.Chart):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def encode(self, *args, color=None, **kwargs):
if isinstance(color, str):
if color.endswith(':F'):
field = color[:-2]
color = color.replace(':F', ':N')
self.configure_scale(nominalColorRange=full_colors(len(set(self._data[field]))))
for arg in args:
if isinstance(arg, altair.X):
kwargs['x'] = arg
elif isinstance(arg, altair.Y):
kwargs['y'] = arg
return super().encode(color=color, **kwargs)
def to_mpl(self):
d = self.to_dict()
fmt = 'o' if d.get('mark', 'point') is 'point' else '-'
def encode(data, encoding, **kwargs):
logging.debug(str(kwargs))
if 'column' in encoding:
channel = encoding.pop('column')
ncols = len(data[channel.get('field')].unique())
for col, (column, df) in enumerate(data.groupby(channel.get('field'))):
ax = plt.gca() if col > 0 else None
plt.subplot(kwargs.get('nrows', 1), ncols, col + 1, sharey=ax).set_title(column)
encode(df, encoding.copy(), secondary_column=col > 0, **kwargs.copy())
elif 'color' in encoding:
channel = encoding.pop('color')
if channel.get('type') == 'quantitative':
colors = full_colors(len(data[channel.get('field')].unique()))
else:
colors = nominal_colors['b']
while len(colors) < len(data[channel.get('field')].unique()):
colors *= 2
for color, (column, df) in zip(colors, data.groupby(channel.get('field'))):
if 'label' in kwargs:
label = kwargs.pop('label') + ', {}'.format(column)
else:
label = str(column)
encode(df, encoding.copy(), color=color, label=label, **kwargs.copy())
elif 'shape' in encoding:
channel = encoding.pop('shape')
markers = ['h', 'v', 'o', 's', '^', 'D', '<', '>']
while len(markers) < len(data[channel.get('field')].unique()):
markers *= 2
logging.debug(str(data[channel.get('field')].unique()))
for marker, (column, df) in zip(markers, data.groupby(channel.get('field'))):
if 'label' in kwargs:
label = kwargs.pop('label') + ', {}'.format(column)
else:
label = str(column)
encode(df, encoding.copy(), marker=marker, label=label, **kwargs.copy())
else:
x_field = encoding.get('x').get('field')
y_field = encoding.get('y').get('field')
plt.xlabel(x_field)
if not kwargs.pop('secondary_column', False):
plt.ylabel(y_field)
else:
plt.tick_params(axis='y', which='both', labelleft='off', labelright='off')
if 'scale' in encoding.get('x'):
plt.xscale(encoding['x']['scale'].get('type', 'linear'))
if 'scale' in encoding.get('y'):
plt.yscale(encoding['y']['scale'].get('type', 'linear'))
plt.plot(data[x_field], data[y_field], fmt, **kwargs)
plt.legend(loc='best', fontsize='small')
encode(self._data, d.get('encoding'))
plt.tight_layout(pad=0.5)
class Arrhenius(Chart):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# self.transform_data(calculate=[Formula(field='1000K/T', expr='1000/datum.T')])
self.data['1000 K / T'] = 1000 / self.data['T']
self.encode(x=X('1000 K / T'))

53
tudplot/tex2grace.py Executable file
View File

@ -0,0 +1,53 @@
from collections import OrderedDict
import logging
import re
patterns = OrderedDict()
patterns['\$'] = ''
patterns['¹'] = r'\\S1\\N'
patterns['²'] = r'\\S2\\N'
patterns['³'] = r'\\S3\\N'
patterns[''] = r'\\S4\\N'
patterns[''] = r'\\S5\\N'
patterns[''] = r'\\S6\\N'
patterns[''] = r'\\S7\\N'
patterns[''] = r'\\S8\\N'
patterns[''] = r'\\S9\\N'
patterns[''] = r'\\S-\\N'
patterns[r'\\star'] = '*'
patterns[r'\\,'] = r'\\-\\- \\+\\+'
patterns[r'\\[;:.]'] = ''
patterns[r'\\math(?:tt|sf|it|rm)({[^}]+})'] = r'\1'
patterns[r'\^({[^}]+}|.)'] = r'\\S\1\\N'
patterns[r'\_({[^}]+}|.)'] = r'\\s\1\\N'
# Remove any left over latex groups as the last step
patterns[r'[{}]'] = ''
# now any patterns that do need braces...
patterns[r'\\tilde '] = r'\\v{{0.1}}~\\v{{-0.1}}\\M{{1}}'
# Greek letters in xmgrace are written by switching to symbol-font:
# "\x a\f{}" will print an alpha and switch back to normal font
greek = {
'alpha': 'a', 'beta': 'b', 'gamma': 'g', 'delta': 'd', 'epsilon': 'e', 'zeta': 'z',
'eta': 'h', 'theta': 'q', 'iota': 'i', 'kappa': 'k', 'lambda': 'l', 'mu': 'm',
'nu': 'n', 'xi': 'x', 'omicron': 'o', 'pi': 'p', 'rho': 'r', 'sigma': 's',
'tau': 't', 'upsilon': 'u', 'phi': 'f', 'chi': 'c', 'psi': 'y', 'omega': 'w',
'varphi': 'j', 'varepsilon': 'e', 'vartheta': 'J', 'varrho': 'r',
'Phi': 'F',
'langle': r'\#{{e1}}', 'rangle': r'\#{{f1}}', 'infty': r'\\c%\\C', 'cdot': r'\#{{d7}}',
'sqrt': r'\#{{d6}}', 'propto': r'\#{{b5}}', 'approx': r'\#{{bb}}'
}
for latex, xmg in greek.items():
patt = r'\\{}'.format(latex)
repl = r'\\x {}\\f{{{{}}}}'.format(xmg)
patterns[patt] = repl
def latex_to_xmgrace(string):
logging.debug('Convert to xmgrace: {}'.format(string))
for patt, repl in patterns.items():
string = re.sub(patt, repl, string)
logging.debug('To -> {}'.format(string))
return string

7
tudplot/tud.mplstyle Executable file
View File

@ -0,0 +1,7 @@
font.size: 16
lines.linewidth: 1.5
lines.markeredgewidth: 1.5
lines.markersize: 6
markers.fillstyle: full
figure.figsize: 8, 6
savefig.dpi: 300

42
tudplot/tud.py Normal file
View File

@ -0,0 +1,42 @@
import re
import matplotlib as mpl
import numpy
tudcolors = {
'a': ('#5D85C3', '#009CDA', '#50B695', '#AFCC50', '#DDDF48', '#FFE05C',
'#F8BA3C', '#EE7A34', '#E9503E', '#C9308E', '#804597'),
'b': ('#005AA9', '#0083CC', '#009D81', '#99C000', '#C9D400', '#FDCA00',
'#F5A300', '#EC6500', '#E6001A', '#A60084', '#721085'),
'c': ('#004E8A', '#00689D', '#008877', '#7FAB16', '#B1BD00', '#D7AC00',
'#D28700', '#CC4C03', '#B90F22', '#951169', '#611C73'),
'd': ('#243572', '#004E73', '#00715E', '#6A8B22', '#99A604', '#AE8E00',
'#BE6F00', '#A94913', '#961C26', '#732054', '#4C226A'),
}
# Store each color value in the dict as defined in the Style-Guide (e.g. tud9c)
tudcolors.update({'tud{}{}'.format(i + 1, s): col for s in tudcolors for i, col in enumerate(tudcolors[s])})
color_maps = {
'blue-red': (tudcolors['tud1b'], tudcolors['tud9b']),
'black-green': ('black', tudcolors['tud5c']),
'violett-green': (tudcolors['tud11c'], tudcolors['tud5b'])
}
nominal_colors = {scheme: [tudcolors[scheme][i] for i in [1, 8, 3, 9, 6, 2]] for scheme in 'abcd'}
def full_colors(N, scheme='b'):
cmap = mpl.colors.LinearSegmentedColormap.from_list('tud{}'.format(scheme), tudcolors[scheme])
return ['#{:02x}{:02x}{:02x}'.format(*cmap(x, bytes=True)[:3]) for x in numpy.linspace(0, 1, N)]
def sequential_colors(N, cmap='blue-red', min=0, max=1):
if cmap in color_maps:
cmap = mpl.colors.LinearSegmentedColormap.from_list('tud_{}'.format(cmap), color_maps[cmap])
elif '-' in cmap:
cols = [tudcolors[k] if 'tud' in k else k for k in cmap.split('-')]
cmap = mpl.colors.LinearSegmentedColormap.from_list(cmap, cols)
else:
cmap = mpl.pyplot.get_cmap(cmap)
return ['#{:02x}{:02x}{:02x}'.format(*cmap(x, bytes=True)[:3]) for x in numpy.linspace(min, max, N)]

220
tudplot/utils.py Normal file
View File

@ -0,0 +1,220 @@
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from collections import Iterable
from matplotlib.cbook import flatten
from itertools import cycle
def facet_plot(dframe, facets, props, ydata, layout=None, newfig=True, figsize=None,
legend=True, individual_legends=False, hide_additional_axes=True, zorder='default', **kwargs):
if newfig:
nr_facets = len(dframe.groupby(facets))
if layout is None:
for i in range(2, nr_facets // 2):
if nr_facets % i == 0:
layout = (nr_facets // i, i)
break
if layout is None:
n = int(np.ceil(nr_facets / 2))
layout = (n, 2)
fig, axs = plt.subplots(
nrows=layout[0],
ncols=layout[1],
sharex=True, sharey=True, figsize=figsize
)
if hide_additional_axes:
for ax in fig.axes[nr_facets:]:
ax.set_axis_off()
else:
fig = plt.gcf()
axs = fig.axes
cycl = cycle(plt.rcParams['axes.prop_cycle'])
prop_styles = {ps: next(cycl) for ps, _ in dframe.groupby(props)}
if zorder is 'default':
dz = 1
zorder = 0
elif zorder is 'reverse':
dz = -1
zorder = 0
else:
dz = 0
if legend:
ax0 = fig.add_subplot(111, frame_on=False, zorder=-9999)
ax0.set_axis_off()
plot_kwargs = kwargs.copy()
for k in ['logx', 'logy', 'loglog']:
plot_kwargs.pop(k, None)
for l, p in prop_styles.items():
ax0.plot([], label=str(l), **p, **plot_kwargs)
ax0.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize='x-small')
for ax, (ps, df) in zip(flatten(axs), dframe.groupby(facets, squeeze=False)):
for prop, df_prop in df.groupby(props):
df_prop[ydata].plot(ax=ax, label=str(prop), zorder=zorder, **prop_styles[prop], **kwargs)
zorder += dz
# ax.title(0.5, 0.1, '{},{}'.format(*ps), transform=ax.transAxes, fontsize='small')
ax.set_title('; '.join([str(x) for x in ps]) if isinstance(ps, tuple) else str(ps), fontsize='x-small')
if individual_legends:
ax.legend(fontsize='x-small')
plt.sca(ax)
rect = (0, 0, 0.85, 1) if legend else (0, 0, 1, 1)
plt.tight_layout(rect=rect, pad=0.1)
return fig, axs
class CurvedText(mpl.text.Text):
"""A text object that follows an arbitrary curve."""
def __init__(self, x, y, text, axes, **kwargs):
super(CurvedText, self).__init__(x[0],y[0],' ', axes, **kwargs)
axes.add_artist(self)
# # saving the curve:
self.__x = x
self.__y = y
self.__zorder = self.get_zorder()
# # creating the text objects
self.__Characters = []
for c in text:
t = mpl.text.Text(0, 0, c, **kwargs)
# resetting unnecessary arguments
t.set_ha('center')
t.set_rotation(0)
t.set_zorder(self.__zorder +1)
self.__Characters.append((c,t))
axes.add_artist(t)
# # overloading some member functions, to assure correct functionality
# # on update
def set_zorder(self, zorder):
super(CurvedText, self).set_zorder(zorder)
self.__zorder = self.get_zorder()
for c,t in self.__Characters:
t.set_zorder(self.__zorder+1)
def draw(self, renderer, *args, **kwargs):
"""
Overload of the Text.draw() function. Do not do
do any drawing, but update the positions and rotation
angles of self.__Characters.
"""
self.update_positions(renderer)
def update_positions(self,renderer):
"""
Update positions and rotations of the individual text elements.
"""
# preparations
# # determining the aspect ratio:
# # from https://stackoverflow.com/a/42014041/2454357
# # data limits
xlim = self.axes.get_xlim()
ylim = self.axes.get_ylim()
# # Axis size on figure
figW, figH = self.axes.get_figure().get_size_inches()
# # Ratio of display units
_, _, w, h = self.axes.get_position().bounds
# # final aspect ratio
aspect = ((figW * w)/(figH * h))*(ylim[1]-ylim[0])/(xlim[1]-xlim[0])
# points of the curve in figure coordinates:
x_fig,y_fig = (
np.array(l) for l in zip(*self.axes.transData.transform([
(i,j) for i,j in zip(self.__x,self.__y)
]))
)
# point distances in figure coordinates
x_fig_dist = (x_fig[1:]-x_fig[:-1])
y_fig_dist = (y_fig[1:]-y_fig[:-1])
r_fig_dist = np.sqrt(x_fig_dist**2+y_fig_dist**2)
# arc length in figure coordinates
l_fig = np.insert(np.cumsum(r_fig_dist),0,0)
# angles in figure coordinates
rads = np.arctan2((y_fig[1:] - y_fig[:-1]),(x_fig[1:] - x_fig[:-1]))
degs = np.rad2deg(rads)
rel_pos = 10
for c,t in self.__Characters:
# finding the width of c:
t.set_rotation(0)
t.set_va('center')
bbox1 = t.get_window_extent(renderer=renderer)
w = bbox1.width
h = bbox1.height
# ignore all letters that don't fit:
if rel_pos+w/2 > l_fig[-1]:
t.set_alpha(0.0)
rel_pos += w
continue
elif c != ' ':
t.set_alpha(1.0)
# finding the two data points between which the horizontal
# center point of the character will be situated
# left and right indices:
il = np.where(rel_pos+w/2 >= l_fig)[0][-1]
ir = np.where(rel_pos+w/2 <= l_fig)[0][0]
# if we exactly hit a data point:
if ir == il:
ir += 1
# how much of the letter width was needed to find il:
used = l_fig[il]-rel_pos
rel_pos = l_fig[il]
# relative distance between il and ir where the center
# of the character will be
fraction = (w/2-used)/r_fig_dist[il]
# # setting the character position in data coordinates:
# # interpolate between the two points:
x = self.__x[il]+fraction*(self.__x[ir]-self.__x[il])
y = self.__y[il]+fraction*(self.__y[ir]-self.__y[il])
# getting the offset when setting correct vertical alignment
# in data coordinates
t.set_va(self.get_va())
bbox2 = t.get_window_extent(renderer=renderer)
bbox1d = self.axes.transData.inverted().transform(bbox1)
bbox2d = self.axes.transData.inverted().transform(bbox2)
dr = np.array(bbox2d[0]-bbox1d[0])
# the rotation/stretch matrix
rad = rads[il]
rot_mat = np.array([
[np.cos(rad), np.sin(rad)*aspect],
[-np.sin(rad)/aspect, np.cos(rad)]
])
# # computing the offset vector of the rotated character
drp = np.dot(dr,rot_mat)
# setting final position and rotation:
t.set_position(np.array([x,y])+drp)
t.set_rotation(degs[il])
t.set_va('center')
t.set_ha('center')
# updating rel_pos to right edge of character
rel_pos += w-used

461
tudplot/xmgrace.py Executable file
View File

@ -0,0 +1,461 @@
import re
import logging
from collections import OrderedDict
from matplotlib.colors import ColorConverter
from matplotlib.cbook import is_string_like
import numpy as np
import matplotlib as mpl
from .tud import tudcolors
from .tex2grace import latex_to_xmgrace
def indexed(list, default=None):
def index(arg):
for i, v in enumerate(list):
if (isinstance(v, tuple) and arg in v) or arg == v:
return i
return default
return index
def escapestr(s):
raw_map = {8: r'\b', 7: r'\a', 12: r'\f', 10: r'\n', 13: r'\r', 9: r'\t', 11: r'\v'}
return r''.join(i if ord(i) > 32 else raw_map.get(ord(i), i) for i in s)
def get_viewport_coords(artist):
"""
Get the viewport coordinates of an artist.
"""
fxy = artist.figure.get_size_inches()
fxy /= fxy.min()
trans = artist.figure.transFigure.inverted()
return trans.transform(artist.get_window_extent()) * fxy[np.newaxis, :]
def get_world_coords(artist):
"""
Get the world coordinates of an artist.
"""
trans = artist.axes.transData.inverted()
return trans.transform(artist.get_window_extent())
def get_world(axis):
xmin, xmax = axis.get_xlim()
ymin, ymax = axis.get_ylim()
return '{}, {}, {}, {}'.format(xmin, ymin, xmax, ymax)
def get_view(axis):
box = axis.get_position()
fx, fy = axis.figure.get_size_inches()
sx = fx / min(fx, fy)
sy = fy / min(fx, fy)
c = np.array([box.xmin*sx, box.ymin*sy, box.xmax*sx, box.ymax*sy])
return '{:.3}, {:.3}, {:.3}, {:.3}'.format(*c)
def get_major_ticks(dim):
def get_major_dticks(axis):
ticks = getattr(axis, 'get_{}ticks'.format(dim))()
scale = getattr(axis, 'get_{}scale'.format(dim))()
if scale is 'log':
value = (ticks[1:] / ticks[:-1]).mean()
else:
value = (ticks[1:] - ticks[:-1]).mean()
return value
return get_major_dticks
agr_attr_lists = {
# Linestyles in xmgrace: None are styles that are by default
# not defined in matplotlib (longer dashes and double dots)
# First entry should always be None, since index starts at 1
'linestyle': ['None', '-', ':', '--', None, '-.', None, None, None],
'marker': ['None', 'o', 's', 'd', '^', '<', 'v', '>', '+', 'x', '*']
}
def get_ticklabels_on(dim):
def get_ticklabels(axis):
tl = getattr(axis, f'{dim}axis').get_ticklabels()
return 'off' if all([x.get_text() == '' for x in tl]) else 'on'
return get_ticklabels
def get_legend(axis):
if axis.get_legend() is not None:
return 'on'
else:
return 'off'
def get_legend_position(axis):
leg = axis.get_legend()
if leg is not None:
return '{:.3f}, {:.3f}'.format(*get_viewport_coords(leg).diagonal())
else:
return '0, 0'
def get_text_position(text):
#return '{:.3f}, {:.3f}'.format(*get_viewport_coords(text)[0])
return '{:.3f}, {:.3f}'.format(*get_viewport_coords(text).mean(axis=0))
def get_arrow_coordinates(text):
arrow = text.arrow_patch
trans = text.axes.transData.inverted()
xy = trans.transform(arrow.get_path().vertices[[0, 2]])
#xy = get_viewport_coords(text)
return '{:.3f}, {:.3f}, {:.3f}, {:.3f}'.format(*xy[0], *xy[1])
class StaticAttribute:
"""
A static attribute that just writes a line to the agr file if it is set.
Functions also as a base class for other attribute classes.
"""
def __init__(self, key, fmt):
"""
Args:
key: The name of the attribute.
fmt: The string which is written to the agr file.
"""
self.key = key
self.fmt = fmt
def format(self, source=None, **kwargs):
"""
Return the (formatted) string of the attribute.
Args:
source: The python object, which is only included here for signature reasons
"""
return self.fmt
class ValueAttribute(StaticAttribute):
"""
An attribute which writes a key value pair into the agr file.
The agr string has the format: '{fmt} {value}'
"""
attr_lists = {
'linestyle': ('None', '-', ':', '--', None, '-.', None, None, None),
'marker': (('', 'None'), 'o', 's', ('d', 'D'), '^', '<', 'v', '>', '+', 'x', '*'),
'fillstyle': ('none', 'full', ),
'color': ['white', 'black'],
}
def reset_colors():
ValueAttribute.attr_lists['color'] = ['white', 'black']
def _get_value(self, source, convert_latex=True):
value = getattr(source, 'get_{}'.format(self.key))()
if isinstance(value, str):
if convert_latex:
value = latex_to_xmgrace(value)
else:
value = value.replace(r'{}', r'{{}}').replace('{{{}}}', '{{}}')
if not self.index:
value = '"{}"'.format(value)
return value
def __init__(self, *args, index=None, function=None, condition=None):
"""
Args:
*args: Arguments of super().__init__()
index:
True if value should be mapped to an index. If this is a str this will
be used as the index lists key.
function: A function that is used to fetch the value from the source.
condition: A function that decides if the attribute is written to the agr file.
"""
super().__init__(*args)
if index:
if index is True:
self.index = self.key
else:
self.index = index
self.attr_lists.setdefault(self.index, [])
else:
self.index = False
if function is not None:
self._get_value = lambda x, **kwargs: function(x)
self.condition = condition or (lambda x: True)
def format(self, source, convert_latex=True, **kwargs):
value = self._get_value(source, convert_latex=convert_latex)
if not self.condition(value):
return None
if self.index:
attr_list = self.attr_lists[self.index]
index = indexed(attr_list)(str(value))
if index is None:
try:
attr_list.append(value)
index = attr_list.index(value)
except AttributeError:
print('index not found:', value, index, attr_list)
index = 1
value = index
logging.debug('fmt: {}, value: {}'.format(self.fmt, value))
return ' '.join([self.fmt, str(value)])
agr_line_attrs = [
StaticAttribute('hidden', 'hidden false'),
StaticAttribute('type', 'type xy'),
ValueAttribute('label', 'legend', condition=lambda lbl: re.search(r'\\sl\\Nine\d+', lbl) is None),
ValueAttribute('linestyle', 'line linestyle', index=True),
ValueAttribute('linewidth', 'line linewidth'),
ValueAttribute('color', 'line color', index=True),
ValueAttribute('marker', 'symbol', index=True),
ValueAttribute('fillstyle', 'symbol fill pattern', index=True),
ValueAttribute('markeredgecolor', 'symbol color', index='color'),
ValueAttribute('markerfacecolor', 'symbol fill color', index='color'),
ValueAttribute('markeredgewidth', 'symbol linewidth'),
]
agr_axis_attrs = [
StaticAttribute('xaxis', 'frame background pattern 1'),
StaticAttribute('xaxis', 'xaxis label char size 1.0'),
StaticAttribute('yaxis', 'yaxis label char size 1.0'),
StaticAttribute('xaxis', 'xaxis ticklabel char size 1.0'),
StaticAttribute('yaxis', 'yaxis ticklabel char size 1.0'),
ValueAttribute('world', 'world', function=get_world),
ValueAttribute('view', 'view', function=get_view),
ValueAttribute('title', 'subtitle'),
ValueAttribute('xlabel', 'xaxis label'),
ValueAttribute('ylabel', 'yaxis label'),
ValueAttribute('xscale', 'xaxes scale Logarithmic', condition=lambda scale: 'log' in scale),
ValueAttribute('xscale', 'xaxis ticklabel format power', condition=lambda scale: 'log' in scale),
ValueAttribute('xscale', 'xaxis ticklabel prec 0', condition=lambda scale: 'log' in scale),
ValueAttribute('xscale', 'xaxis tick minor ticks', function=lambda ax: 9 if 'log' in ax.get_xscale() else 4),
ValueAttribute('yscale', 'yaxes scale Logarithmic', condition=lambda scale: 'log' in scale),
ValueAttribute('yscale', 'yaxis ticklabel format power', condition=lambda scale: 'log' in scale),
ValueAttribute('yscale', 'yaxis ticklabel prec 0', condition=lambda scale: 'log' in scale),
ValueAttribute('xscale', 'yaxis tick minor ticks', function=lambda ax: 9 if 'log' in ax.get_yscale() else 4),
ValueAttribute('xticks', 'xaxis tick major', function=get_major_ticks('x')),
ValueAttribute('yticks', 'yaxis tick major', function=get_major_ticks('y')),
ValueAttribute('xticklabels', 'xaxis ticklabel', function=get_ticklabels_on('x')),
ValueAttribute('yticklabels', 'yaxis ticklabel', function=get_ticklabels_on('y')),
ValueAttribute('xlabelposition', 'xaxis label place',
function=lambda ax: 'opposite' if ax.xaxis.get_label_position() == 'top' else 'normal'),
ValueAttribute('xtickposition', 'xaxis ticklabel place',
function=lambda ax: 'opposite' if all([t.get_position()[1] >= 0.9 for t in ax.xaxis.get_ticklabels()]) else 'normal'),
ValueAttribute('ylabelposition', 'yaxis label place',
function=lambda ax: 'opposite' if ax.yaxis.get_label_position() == 'right' else 'normal'),
ValueAttribute('ytickposition', 'yaxis ticklabel place',
function=lambda ax: 'opposite' if all([t.get_position()[0] >= 0.9 for t in ax.yaxis.get_ticklabels()]) else 'normal'),
# tax.yaxis.get_ticks_position() == 'right' else 'normal'),
ValueAttribute('legend', 'legend', function=get_legend),
StaticAttribute('legend', 'legend loctype view'),
StaticAttribute('legend', 'legend char size 1.0'),
ValueAttribute('legend', 'legend', function=get_legend_position)
]
agr_text_attrs = [
StaticAttribute('string', 'on'),
StaticAttribute('string', 'loctype view'),
StaticAttribute('string', 'char size 1.0'),
ValueAttribute('position', '', function=get_text_position),
ValueAttribute('text', 'def')
]
agr_arrow_attrs = [
StaticAttribute('line', 'on'),
StaticAttribute('line', 'loctype world'),
StaticAttribute('line', 'color 1'),
StaticAttribute('line', 'linewidth 2'),
StaticAttribute('line', 'linestyle 1'),
StaticAttribute('line', 'arrow 2'),
ValueAttribute('line', '', function=get_arrow_coordinates),
]
class AgrFile:
head = '@version 50125\n'
body = tail = ''
indent = 0
kwargs = {}
def writeline(self, text, part='body', **kwargs):
self.kwargs = {**self.kwargs, **kwargs}
content = getattr(self, part)
content += '@' + ' ' * self.indent + escapestr(text.format(**self.kwargs)) + '\n'
setattr(self, part, content)
def writedata(self, data):
self.tail += '@target {axis}.{line}\n@type xy\n'.format(**self.kwargs)
for x, y in data:
if np.isfinite([x, y]).all():
self.tail += '{} {}\n'.format(x, y)
self.tail += '&\n'
def save(self, filename):
with open(filename, 'w') as file:
file.write(self.head)
file.write(self.body)
file.write(self.tail)
def _process_attributes(attrs, source, agr, prefix=''):
for attr, attr_dict in attrs.items():
attr_type = attr_dict['type']
if 'static' in attr_type:
value = ''
elif 'function' in attr_type:
value = attr_dict['function'](source)
else:
value = getattr(source, 'get_{}'.format(attr))()
if 'condition' in attr_dict:
if not attr_dict['condition'](value):
continue
if is_string_like(value):
value = latex_to_xmgrace(value)
if 'index' in attr_type:
attr_list = agr_attr_lists[attr_dict.get('maplist', attr)]
index = indexed(attr_list)(value)
if index is None:
if 'map' in attr_type:
attr_list.append(value)
index = attr_list.index(value)
else:
index = 1
value = index
agr.writeline(prefix + attr_dict['fmt'], attr=attr, value=value)
def process_attributes(attrs, source, agr, prefix='', **kwargs):
for attr in attrs:
fmt = attr.format(source, **kwargs)
if fmt is not None:
agr.writeline(prefix + fmt)
def export_to_agr(figure, filename, **kwargs):
"""
Export a matplotlib figure to xmgrace format.
"""
ValueAttribute.reset_colors()
cc = ColorConverter()
agr = AgrFile()
# agr_attr_lists['color'] = ['white', 'black']
# agr_colors =
papersize = figure.get_size_inches()*120
agr.writeline('page size {}, {}'.format(*papersize))
agr.writeline('default char size {}'.format(mpl.rcParams['font.size'] / 12))
for i, axis in enumerate(figure.axes):
agr_axis = 'g{}'.format(i)
agr.indent = 0
agr.writeline('{axis} on', axis=agr_axis)
agr.writeline('{axis} hidden false')
agr.writeline('{axis} type XY')
agr.writeline('{axis} stacked false')
agr.writeline('with {axis}')
agr.indent = 4
process_attributes(agr_axis_attrs, axis, agr, **kwargs)
for j, line in enumerate(axis.lines):
agr.kwargs['line'] = 's{}'.format(j)
process_attributes(agr_line_attrs, line, agr, '{line} ', **kwargs)
agr.writedata(line.get_xydata())
for text in axis.texts:
agr.indent = 0
agr.writeline('with string')
agr.indent = 4
process_attributes(agr_text_attrs, text, agr, 'string ', **kwargs)
# this is a text of an arrow-annotation
if hasattr(text, 'arrow_patch'):
agr.indent = 0
agr.writeline('with line')
agr.indent = 4
agr.writeline(f'line {agr_axis}')
process_attributes(agr_arrow_attrs, text, agr, 'line ', **kwargs)
agr.indent = 0
agr.writeline('line def')
agr.indent = 0
tudcol_rev = {}
for name, color in tudcolors.items():
if isinstance(color, str):
rgba, = cc.to_rgba_array(color)
tudcol_rev[tuple(rgba)] = name
for i, color in enumerate(ValueAttribute.attr_lists['color']):
# print(i, color)
if color is not 'none':
rgba, = cc.to_rgba_array(color)
rgb_tuple = tuple(int(255 * c) for c in rgba[:3])
color_name = tudcol_rev.get(tuple(rgba), color)
agr.writeline('map color {index} to {rgb}, "{color}"',
part='head', index=i, rgb=rgb_tuple, color=color_name)
agr.save(filename)
def load_agr_data(agrfile, nolabels=False):
"""
Load all named data sets from an agrfile.
"""
graphs = OrderedDict()
cur_graph = None
target = None
with open(agrfile, 'r', errors='replace') as f:
lines = f.readlines()
for org_line in lines:
line = org_line.lower()
if '@with' in line:
graph_id = line.split()[1]
if graph_id not in graphs:
graphs[graph_id] = {}
cur_graph = graphs[graph_id]
elif 'legend' in line and cur_graph is not None:
ma = re.search('([sS]\d+) .+ "(.*)"', org_line)
if ma is not None:
label = ma.group(2)
sid = ma.group(1).lower()
if label == '' or nolabels:
gid = [k for k, v in graphs.items() if v is cur_graph][0]
label = '{}.{}'.format(gid, sid)
cur_graph[sid] = {'label': label}
elif '@target' in line:
ma = re.search('(g\d+)\.(s\d+)', line.lower())
gid = ma.group(1)
sid = ma.group(2)
target = []
if sid not in graphs[gid]:
graphs[gid][sid] = {'label': '{}.{}'.format(gid, sid)}
graphs[gid][sid]['data'] = target
elif target is not None and '@type' in line:
continue
elif '&' in line:
target = None
elif target is not None:
target.append([float(d) for d in line.split()])
data = OrderedDict()
for _, graph in graphs.items():
for _, set in graph.items():
if 'data' in set:
data[set['label']] = np.array(set['data'])
else:
print(_, set)
data[set['label']] = np.empty((0,))
return data