356 lines
12 KiB
Python
356 lines
12 KiB
Python
import numpy as np
|
|
from matplotlib.colors import ColorConverter
|
|
from matplotlib.backends import backend_agg
|
|
import matplotlib.pyplot as plt
|
|
import textwrap
|
|
from itertools import zip_longest
|
|
from collections import defaultdict
|
|
|
|
from . import tex2grace
|
|
from .tex2grace import latex_to_xmgrace
|
|
|
|
|
|
def update_labels(labels, axis=None):
    """Assign labels to the lines of an axis, matched by position.

    Args:
        labels: Iterable of labels; the i-th label goes to ``axis.lines[i]``.
        axis (opt.): Axis whose lines are relabelled, if None gca() is used.

    Lines without a matching label get an empty label. Labels beyond the
    number of lines are ignored.
    """
    if axis is None:
        axis = plt.gca()
    remaining = iter(labels)
    for line in axis.lines:
        # Iterate over the lines only: padding the *lines* side (as
        # zip_longest did) would call set_label on the '' fill value.
        line.set_label(next(remaining, ''))
|
|
|
|
|
|
def sanitize_strings(dic):
    """Visit every string value of ``dic`` and compute brace-doubled copies.

    NOTE(review): ``str.replace`` returns a new string and the results are
    discarded, so ``dic`` is left unchanged and this function is effectively
    a no-op. Storing the results would put doubled braces into any text that
    is later substituted through ``str.format``, because substituted values
    are not re-parsed. Confirm the intent before making this effective.
    """
    for value in dic.values():
        if not isinstance(value, str):
            continue
        value.replace('{', '{{')
        value.replace('}', '}}')
|
|
|
|
|
|
def get_world_coords(artist):
    """Return the artist's window extent mapped through the inverse data transform.

    The extent reported by ``artist.get_window_extent()`` is passed through
    the inverse of ``artist.axes.transData``.
    """
    to_data = artist.axes.transData.inverted()
    extent = artist.get_window_extent()
    return to_data.transform(extent)
|
|
|
|
|
|
class AgrText:
    """A ``string`` entry of an agr file, built from a matplotlib text artist."""

    # Template rendered by __str__; the placeholders are instance attributes.
    fmt = (
        'string on\n'
        'string loctype view\n'
        'string {position}\n'
        'string char size {size:.2f}\n'
        'string color {color}\n'
        'string def "{value}"\n'
    )

    def __init__(self, text, agr_axis):
        """Collect content, size, color and position of ``text``.

        Args:
            text: Text artist to export.
            agr_axis: Owning axis wrapper; its ``agr_figure`` supplies the
                font scale, offset and page scale.
        """
        self.text = text
        self.agr_axis = agr_axis
        self.agr_figure = agr_axis.agr_figure

        self.value = latex_to_xmgrace(text.get_text())
        self.size = text.get_fontsize() * self.agr_figure.fontscale
        # Registers the color in the shared AgrLine color table if needed.
        self.color = AgrLine.color_index(text.get_color())
        self.get_position()
        sanitize_strings(self.__dict__)

    def get_position(self):
        """Store the formatted position of the text as ``self.position``.

        The first corner of the text's window extent is mapped through the
        inverse figure transform, shifted by the figure offset and multiplied
        by the page scale.
        """
        to_figure = self.agr_axis.agr_figure.figure.transFigure.inverted()
        corner = to_figure.transform(self.text.get_window_extent())[0]
        scaled = (corner + self.agr_figure.offset) * self.agr_figure.pagescale
        self.position = '{:.5f}, {:.5f}'.format(*scaled)

    def __str__(self):
        """Render the template with the instance attributes."""
        return self.fmt.format(**vars(self))
|
|
|
|
|
|
class AgrLine:
    """A data set of an agr file, built from a matplotlib line.

    On construction every ``get_*`` method is run; each stores one or more
    attributes that ``fmt`` refers to. ``colors`` is shared by all instances
    (and by ``AgrText``): it is the color table that ``AgrFigure`` writes out.
    """

    # Template rendered by __str__; the placeholders are instance attributes.
    fmt = """hidden {hidden}
type {type}
legend "{label}"
line linestyle {linestyle}
line linewidth {linewidth}
line color {color}
symbol {marker}
symbol color {markeredgecolor}
symbol fill color {markerfacecolor}
symbol fill pattern {markerfill}
symbol linewidth {markeredgewidth}
"""
    # Factor applied to matplotlib line widths.
    width_scale = 2
    cc = ColorConverter()
    # matplotlib linestyle -> xmgrace linestyle number.
    linestyles = {'None': 0, '-': 1, ':': 2, '--': 3, '-.': 5}
    # matplotlib marker -> xmgrace symbol number; unknown markers map to 1.
    markers = defaultdict(
        lambda: 1,
        {'': 0, 'None': 0, 'o': 1, 's': 2, 'd': 3, '^': 4, '<': 5, 'v': 6, '>': 7, '+': 8, 'x': 9, '*': 10}
    )
    # The index of a fill style in this tuple is the emitted fill pattern.
    fillstyles = ('none', 'full', 'right', 'left', 'bottom', 'top')
    # Shared color table; the index of a color is its number in the agr file.
    colors = ['white', 'black']

    @staticmethod
    def color_index(col):
        """Return the index of ``col`` in the shared color table.

        Colors not seen before are appended to ``AgrLine.colors`` first.
        """
        if col not in AgrLine.colors:
            AgrLine.colors.append(col)
        return AgrLine.colors.index(col)

    @property
    def data(self):
        """The xy data of the line as an agr data block, terminated by ``&``."""
        rows = ['{} {}\n'.format(x, y) for x, y in self.line.get_xydata()]
        return '@type xy\n' + ''.join(rows) + '&'

    def get_label(self):
        """Store the legend text of the line as ``self.label``.

        Labels starting with an underscore are exported as an empty string.
        matplotlib leaves such labels out of its own legends, and its
        automatic labels (e.g. ``_line0``, or ``_child0`` in newer releases)
        use this prefix. A missing label (None) is treated the same way.
        """
        lbl = self.line.get_label()
        if lbl is None or str(lbl).startswith('_'):
            self.label = ''
        else:
            self.label = latex_to_xmgrace(lbl)

    def get_linestyle(self):
        """Store the xmgrace linestyle number as ``self.linestyle``.

        Styles without an entry in ``linestyles`` fall back to 1, the value
        used for ``'-'``.
        """
        self.linestyle = self.linestyles.get(self.line.get_linestyle(), 1)

    def get_linewidth(self):
        """Store the scaled line width as ``self.linewidth``."""
        self.linewidth = self.line.get_linewidth() * self.width_scale

    def get_color(self):
        """Store the color table index of the line color as ``self.color``."""
        self.color = AgrLine.color_index(self.line.get_color())

    def get_marker(self):
        """Store symbol type, colors, edge width and fill pattern of the markers."""
        mk = self.line.get_marker()
        try:
            # .get() does not insert missing keys into the defaultdict.
            self.marker = self.markers.get(mk, 1)
        except TypeError:
            # Unhashable marker specification: use the default symbol.
            self.marker = 1
        mfc = self.line.get_markerfacecolor()
        self.markerfacecolor = AgrLine.color_index(mfc)
        mec = self.line.get_markeredgecolor()
        self.markeredgecolor = AgrLine.color_index(mec)
        self.markeredgewidth = self.line.get_markeredgewidth() * self.width_scale
        self.markerfill = self.fillstyles.index(self.line.get_fillstyle())

    def __init__(self, line, agr_axis):
        """Collect all exported properties of ``line``.

        Args:
            line: Line artist to export.
            agr_axis: Owning axis wrapper.
        """
        self.agr_axis = agr_axis
        self.line = line
        self.hidden = 'false'
        self.type = 'xy'

        # Run all get_ methods. dir() lists names alphabetically, so
        # get_color registers the line color before get_marker registers
        # the marker colors.
        for d in dir(self):
            if d.startswith('get_'):
                getattr(self, d)()

        sanitize_strings(self.__dict__)

    def __str__(self):
        """Render the template with the instance attributes."""
        return self.fmt.format(**self.__dict__)
|
|
|
|
|
|
class AgrAxis:
    """A graph of an agr file, built from a matplotlib axis.

    On construction every ``get_*`` method is run; each stores the attributes
    that ``fmt`` refers to. Afterwards the lines and texts of the axis are
    wrapped in ``AgrLine`` and ``AgrText`` objects.
    """

    # Template rendered by __str__; the placeholders are instance attributes.
    # The title is quoted like the other string values (axis labels).
    fmt = """world {world}
view {view}
title "{title}"
yaxes scale {yscale}
yaxis tick major {yticks}
xaxis label "{xlabel}"
xaxis label place {xlabelpos}
xaxis label char size {labelsize}
xaxis tick major {xticks}
xaxis ticklabel {xticklabel}
xaxis ticklabel place {xticklabelpos}
xaxis ticklabel char size {ticklabelsize}
xaxes scale {xscale}
yaxis label "{ylabel}"
yaxis label place {ylabelpos}
yaxis label char size {labelsize}
yaxis ticklabel {yticklabel}
yaxis ticklabel place {yticklabelpos}
yaxis ticklabel char size {ticklabelsize}
legend {legend}
legend loctype world
legend {legend_pos}
legend char size {legend_fontsize}
"""

    # NOTE: helper names must not start with 'get_', otherwise __init__
    # would call them without arguments.
    @staticmethod
    def _tick_spacing(ticks, scale):
        """Return the mean spacing of consecutive ticks.

        For ``scale == 'log'`` this is the mean ratio of neighbouring ticks,
        otherwise the mean difference. With fewer than two ticks the result
        is NaN (mean of an empty array).
        """
        ticks = np.asarray(ticks)
        if scale == 'log':
            return (ticks[1:] / ticks[:-1]).mean()
        return (ticks[1:] - ticks[:-1]).mean()

    @staticmethod
    def _ticklabel_state(ticklabels, coord):
        """Return ``(ticklabel, ticklabelpos)`` for a list of tick labels.

        Args:
            ticklabels: Tick label artists of one axis.
            coord: Index into ``get_position()`` that tells the two sides of
                the axis apart (1 for x tick labels, 0 for y tick labels).

        Only visible labels count. No visible label means 'off'; labels at
        a single position are 'opposite' if that position is 1 and 'normal'
        otherwise; labels at several positions are 'both'.
        """
        positions = {t.get_position()[coord] for t in ticklabels if t.get_visible()}
        if not positions:
            return 'off', 'normal'
        if len(positions) == 1:
            return 'on', 'opposite' if 1 in positions else 'normal'
        return 'on', 'both'

    def get_world(self):
        """Store axis limits as ``self.world`` and the axes box as ``self.view``.

        The box returned by ``axis.get_position()`` is shifted by the figure
        offset and multiplied by the page scale.
        """
        xmin, xmax = self.axis.get_xlim()
        ymin, ymax = self.axis.get_ylim()
        self.world = '{:.3}, {:.3}, {:.3}, {:.3}'.format(xmin, ymin, xmax, ymax)
        box = self.axis.get_position()
        sx, sy = self.agr_figure.pagescale
        offx, offy = self.agr_figure.offset
        self.view = '{:.3}, {:.3}, {:.3}, {:.3}'.format(
            (box.xmin + offx)*sx, (box.ymin + offy)*sy,
            (box.xmax + offx)*sx, (box.ymax + offy)*sy
        )

    def get_title(self):
        """Store the converted axis title as ``self.title``."""
        self.title = latex_to_xmgrace(self.axis.get_title())

    def get_xyaxis(self):
        """Store labels, scales, tick spacing and tick label settings of both axes."""
        self.xlabel = latex_to_xmgrace(self.axis.get_xlabel())
        xpos = self.axis.xaxis.get_label_position()
        self.xlabelpos = 'normal' if xpos == 'bottom' else 'opposite'

        self.ylabel = latex_to_xmgrace(self.axis.get_ylabel())
        ypos = self.axis.yaxis.get_label_position()
        self.ylabelpos = 'normal' if ypos == 'left' else 'opposite'

        xsc = self.axis.get_xscale()
        self.xscale = 'Logarithmic' if 'log' in xsc else 'Normal'
        self.xticks = self._tick_spacing(self.axis.get_xticks(), xsc)

        ysc = self.axis.get_yscale()
        self.yscale = 'Logarithmic' if 'log' in ysc else 'Normal'
        self.yticks = self._tick_spacing(self.axis.get_yticks(), ysc)

        # x tick labels are told apart by their y position and vice versa.
        self.xticklabel, self.xticklabelpos = self._ticklabel_state(
            self.axis.get_xticklabels(), 1)
        self.yticklabel, self.yticklabelpos = self._ticklabel_state(
            self.axis.get_yticklabels(), 0)

        self.labelsize = self.axis.xaxis.get_label().get_fontsize() * self.agr_figure.fontscale
        xticklabels = self.axis.xaxis.get_ticklabels()
        if xticklabels:
            fs = xticklabels[0].get_fontsize()
            self.ticklabelsize = fs * self.agr_figure.fontscale
        else:
            # An axis without ticks has no label to read the size from;
            # reuse the axis label size instead of raising IndexError.
            self.ticklabelsize = self.labelsize

    def get_legend(self):
        """Store legend visibility, position and font size.

        Without a legend the values are 'off', '0, 0' and 0. Otherwise the
        position is the diagonal of the legend extent returned by
        ``get_world_coords`` and the font size is taken from the first
        legend text (0 if the legend has no texts).
        """
        leg = self.axis.get_legend()
        if leg is None:
            self.legend = 'off'
            self.legend_pos = '0, 0'
            self.legend_fontsize = 0
            return

        self.legend = 'on'
        pos = get_world_coords(leg)
        self.legend_pos = '{:.3f}, {:.3f}'.format(*pos.diagonal())
        texts = leg.get_texts()
        if texts:
            self.legend_fontsize = texts[0].get_fontsize() * self.agr_figure.fontscale
        else:
            self.legend_fontsize = 0

    def __init__(self, axis, agr_fig):
        """Collect all exported properties of ``axis``.

        Args:
            axis: Axis to export.
            agr_fig: Owning figure wrapper; supplies font scale, offset and
                page scale.
        """
        self.agr_figure = agr_fig
        self.axis = axis

        # run all get_ methods
        for d in dir(self):
            if d.startswith('get_'):
                getattr(self, d)()

        sanitize_strings(self.__dict__)
        self.lines = {'s{}'.format(i): AgrLine(l, self) for i, l in enumerate(axis.lines)}
        self.texts = [AgrText(t, self) for t in self.axis.texts]

    def __str__(self):
        """Render the graph settings followed by its data sets and strings."""
        parts = [self.fmt.format(**self.__dict__)]
        for k, l in self.lines.items():
            # Every line of a data set is prefixed with its set name.
            parts.append(textwrap.indent(str(l), prefix=k + ' '))
        for txt in self.texts:
            parts.append('with string\n')
            parts.append(textwrap.indent(str(txt), prefix=' '))
        return ''.join(parts)
|
|
|
|
|
|
class AgrFigure:
    """An agr file, built from a matplotlib figure.

    Construction draws the figure on an Agg canvas, derives the page
    properties and wraps every axis of the figure in an ``AgrAxis``.
    ``str()`` of the instance is the content of the agr file.
    """

    dpi = 96
    # Base font scale; instances store a value scaled to the figure size.
    fontscale = 0.5
    fmt = '@version 50125\n@page size {page}\n'
    data_fmt = '@target {target}\n{data}\n'

    def collect_data(self):
        """Return the data blocks of all lines, keyed by ``GRAPH.SET`` in upper case."""
        return {
            '{}.{}'.format(graph_id.upper(), set_id.upper()): agr_line.data
            for graph_id, agr_axis in self.axes.items()
            for set_id, agr_line in agr_axis.lines.items()
        }

    def get_figprops(self):
        """Store page size, font scale and page scale of the figure.

        The figure size is enlarged by the factor ``1 + offset`` per
        direction; both scales are divided by the shorter of the enlarged
        sides.
        """
        width, height = self.figure.get_size_inches()
        scale_x, scale_y = (1 + self.offset)
        shortest = min(width * scale_x, height * scale_y)
        self.page = '{}, {}'.format(
            int(scale_x * self.dpi * width), int(scale_y * self.dpi * height)
        )
        self.fontscale = AgrFigure.fontscale / shortest
        self.pagescale = np.array([width, height]) / shortest

    def __init__(self, figure, offset_horizontal=0, offset_vertical=0, convert_latex=True):
        """Collect everything needed to write ``figure`` as an agr file.

        Args:
            figure: Figure to export.
            offset_horizontal, offset_vertical (opt.): Offsets in x and y
                direction, stored together as ``self.offset``.
            convert_latex (opt.): Assigned to ``tex2grace.do_latex_conversion``.
        """
        tex2grace.do_latex_conversion = convert_latex
        self.figure = figure
        self.offset = np.array([offset_horizontal, offset_vertical])
        # The figure has to be drawn before its properties are read.
        canvas = backend_agg.FigureCanvasAgg(figure)
        canvas.draw()

        # Call every get_ method of the instance.
        getter_names = [name for name in dir(self) if name.startswith('get_')]
        for name in getter_names:
            getattr(self, name)()
        sanitize_strings(self.__dict__)

        self.axes = {
            'g{}'.format(index): AgrAxis(mpl_axis, self)
            for index, mpl_axis in enumerate(self.figure.axes)
        }

    def __str__(self):
        """Render header, color map, graphs and data blocks, in that order."""
        parts = [self.fmt.format(**self.__dict__)]

        for index, colour in enumerate(AgrLine.colors):
            # in matplotlib-1.5 to_rgb can not handle 'none', this was fixed in 2.0
            rgb = [int(channel * 255) for channel in AgrLine.cc.to_rgba(colour)[:3]]
            parts.append(
                '@map color {i} to ({rgb[0]}, {rgb[1]}, {rgb[2]}), "{col}"\n'.format(
                    i=index, rgb=rgb, col=colour)
            )

        for graph_id, agr_axis in self.axes.items():
            parts.append(
                textwrap.indent("on\nhidden false\ntype XY\nstacked false\n",
                                prefix='@{} '.format(graph_id))
            )
            parts.append('@with {}\n'.format(graph_id))
            parts.append(textwrap.indent(str(agr_axis), prefix='@ '))

        for target, block in self.collect_data().items():
            parts.append(self.data_fmt.format(target=target, data=block))

        return ''.join(parts)
|
|
|
|
|
|
def saveagr(fname, figure=None, offset_x=0, offset_y=0, convert_latex=True):
    """
    Save figure as xmgrace plot.

    If no figure is provided, this will save the current figure.

    Args:
        fname: Filename of the agr plot
        figure (opt.): Matplotlib figure to save, if None gcf() is used.
        offset_x, offset_y (opt.): Add an offset in x or y direction to the xmgrace plot.
        convert_latex (opt.): If latex strings will be converted to xmgrace.
    """
    if figure is None:
        figure = plt.gcf()
    # Render the complete file content before opening the target: opening
    # with 'w' truncates the file, so a failing conversion would otherwise
    # leave an empty file behind (and destroy an existing one).
    af = AgrFigure(figure, offset_horizontal=offset_x, offset_vertical=offset_y,
                   convert_latex=convert_latex)
    content = str(af)
    with open(fname, 'w') as f:
        f.write(content)
|