import altair from altair import Config, Color, Shape, Column, Row, Encoding, Scale, Axis from random import randint import os import logging import matplotlib.pyplot as plt from .tud import nominal_colors, full_colors def filter_nulltime_json(data): if 'time' in data: data = data[data.time > 0] return altair.pipe(data, altair.to_json) altair.renderers.enable('notebook') altair.data_transformers.register('json_logtime', filter_nulltime_json) altair.data_transformers.enable('json_logtime') def my_theme(*args, **kwargs): return { 'range': { 'ordinal': altair.VgScheme('viridis'), 'ramp': {'scheme': 'viridis'} } } altair.themes.register('my-theme', my_theme) altair.themes.enable('my-theme') class BaseMixin(Encoding): def __init__(self, *args, **kwargs): kwargs.setdefault('scale', altair.Scale(zero=False)) super().__init__(*args, **kwargs) class LogMixin(BaseMixin): def __init__(self, *args, **kwargs): kwargs['scale'] = altair.Scale(type='log') super().__init__(*args, **kwargs) class X(altair.X, BaseMixin): pass class Y(altair.Y, BaseMixin): pass class LogX(altair.X, LogMixin): pass class LogY(altair.Y, LogMixin): pass class DataHandler(altair.Data): def __init__(self, df): self._filename = '.altair.json' with open(self._filename, 'w') as f: f.write(df.to_json()) super().__init__(url=self._filename) class Chart(altair.Chart): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def encode(self, *args, color=None, **kwargs): if isinstance(color, str): if color.endswith(':F'): field = color[:-2] color = color.replace(':F', ':N') self.configure_scale(nominalColorRange=full_colors(len(set(self._data[field])))) for arg in args: if isinstance(arg, altair.X): kwargs['x'] = arg elif isinstance(arg, altair.Y): kwargs['y'] = arg return super().encode(color=color, **kwargs) def to_mpl(self): d = self.to_dict() fmt = 'o' if d.get('mark', 'point') is 'point' else '-' def encode(data, encoding, **kwargs): logging.debug(str(kwargs)) if 'column' in encoding: channel = encoding.pop('column') ncols = len(data[channel.get('field')].unique()) for col, (column, df) in enumerate(data.groupby(channel.get('field'))): ax = plt.gca() if col > 0 else None plt.subplot(kwargs.get('nrows', 1), ncols, col + 1, sharey=ax).set_title(column) encode(df, encoding.copy(), secondary_column=col > 0, **kwargs.copy()) elif 'color' in encoding: channel = encoding.pop('color') if channel.get('type') == 'quantitative': colors = full_colors(len(data[channel.get('field')].unique())) else: colors = nominal_colors['b'] while len(colors) < len(data[channel.get('field')].unique()): colors *= 2 for color, (column, df) in zip(colors, data.groupby(channel.get('field'))): if 'label' in kwargs: label = kwargs.pop('label') + ', {}'.format(column) else: label = str(column) encode(df, encoding.copy(), color=color, label=label, **kwargs.copy()) elif 'shape' in encoding: channel = encoding.pop('shape') markers = ['h', 'v', 'o', 's', '^', 'D', '<', '>'] while len(markers) < len(data[channel.get('field')].unique()): markers *= 2 logging.debug(str(data[channel.get('field')].unique())) for marker, (column, df) in zip(markers, data.groupby(channel.get('field'))): if 'label' in kwargs: label = kwargs.pop('label') + ', {}'.format(column) else: label = str(column) encode(df, encoding.copy(), marker=marker, label=label, **kwargs.copy()) else: x_field = encoding.get('x').get('field') y_field = encoding.get('y').get('field') plt.xlabel(x_field) if not kwargs.pop('secondary_column', False): plt.ylabel(y_field) else: plt.tick_params(axis='y', which='both', labelleft='off', labelright='off') if 'scale' in encoding.get('x'): plt.xscale(encoding['x']['scale'].get('type', 'linear')) if 'scale' in encoding.get('y'): plt.yscale(encoding['y']['scale'].get('type', 'linear')) plt.plot(data[x_field], data[y_field], fmt, **kwargs) plt.legend(loc='best', fontsize='small') encode(self._data, d.get('encoding')) plt.tight_layout(pad=0.5) class Arrhenius(Chart): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # self.transform_data(calculate=[Formula(field='1000K/T', expr='1000/datum.T')]) self.data['1000 K / T'] = 1000 / self.data['T'] self.encode(x=X('1000 K / T'))