Butterfly/Tornado Chart

I was trying to find the best solution for making butterfly charts (find an example here: Butterfly Chart | Data Viz Project), and given there was no good built-in support amongst the usual graph libraries, I decided to give it a whirl with bokeh.

I just wanted to share my solution so if anyone decides to make something similar in the future, there is at least an example in bokeh (I couldn’t find examples when I was searching).

Though the code below is a bit long, the key points are:

  • You need to make the data going to the left negative so it appears to head to the left.
  • Likewise, the x-axis tick labels need to be formatted as absolute values so the left side does not show negative numbers.
  • From what I’ve gathered, bokeh doesn’t allow “moving” the y-axis, so hide it and put custom labels in the “middle” of the graph

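For the example below, I used a DataFrame that looks like this (note that the function's default group values are '男' and '女', so for this data you would pass l_col_val='M' and r_col_val='F'):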

Gender Age_Group Count
M “10-14” 40
F “10-14” 20
M “15-19” 13
F “15-19” 25

import pandas as pd

from bokeh.plotting import figure
from bokeh.models import LabelSet
from bokeh.models.formatters import CustomJSTickFormatter
from bokeh.models.sources import ColumnDataSource

# Suffix appended to the left-hand group's column name to form the column that
# keeps the original (non-negated) values; that column supplies the text of the
# left-side bar labels.
ABS_COL_POSTFIX='_Abs'

def create_butterfly_graph(
        df:pd.DataFrame,
        l_col_val='男', r_col_val='女',
        y_col_name='Age_Group', x_col_name='Count', group_col_name='Gender',
        width=250, height=250,
        l_color='lightblue', r_color='pink',
        ):
    '''
    Creates a butterfly (tornado) graph and returns the bokeh figure.

    DF is structured as so:
    x_col_name | y_col_name | group_col_name | ...and other columns
    group_col_name is either l_col_val or r_col_val

    Parameters:
    df             -- long-format DataFrame, one row per (y value, group) pair
    l_col_val      -- value of group_col_name drawn on the left side
    r_col_val      -- value of group_col_name drawn on the right side
    y_col_name     -- column holding the categories (one bar pair per value)
    x_col_name     -- column holding the bar lengths
    group_col_name -- column holding l_col_val / r_col_val
    width, height  -- size passed to the bokeh figure
    l_color        -- fill color of the left bars
    r_color        -- fill color of the right bars

    Returns the configured bokeh figure.
    '''
    # The pivot helper has no defaults, so the column configuration must be
    # forwarded explicitly (keywords guard against mixing up x/y order).
    pivot_df = _helper_create_pivot_df(
        df,
        l_col_val=l_col_val, r_col_val=r_col_val,
        x_col_name=x_col_name, y_col_name=y_col_name,
        group_col_name=group_col_name,
    )
    source = ColumnDataSource(pivot_df)

    # Symmetric x range: the largest bar on either side plus padding so the
    # value labels next to the bars still fit inside the plot.
    max_x = pivot_df[[r_col_val, l_col_val+ABS_COL_POSTFIX]].max().max() + 10

    p = figure(
        # Reverse the categories so the first one is drawn at the top
        # (descending, smallest up top).
        y_range=pivot_df[y_col_name].tolist()[::-1],
        x_axis_location="below",
        x_range=(-max_x, max_x),
        width=width, height=height,
        toolbar_location=None
    )

    _helper_create_hbars(p, source, l_col_val, r_col_val, y_col_name, l_color, r_color)
    _helper_create_labels(p, source, l_col_val, r_col_val, y_col_name)
    _helper_set_axes(p)

    return p
    
    
def _helper_create_hbars(
        p, source,
        l_col_val, r_col_val,
        y_col_name,
        l_color, r_color):
    '''
    Adds the horizontal bars for both sides of the butterfly to figure p.

    The left column is drawn first, then the right one; each is drawn with
    its own color, reading its values from the matching column of source.
    '''
    bar_height = 0.8
    sides = (
        (l_col_val, l_color),
        (r_col_val, r_color),
    )
    for column, fill in sides:
        p.hbar_stack([column], y=y_col_name, source=source,
                     height=bar_height, color=[fill])

def _helper_create_labels(p, source, l_col_val, r_col_val, y_col_name):
    '''
    Adds three label sets to figure p:
    - the category names at x=0 (y-axis replacement),
    - the left bar values (taken from the un-negated "_Abs" column),
    - the right bar values.
    '''
    # Settings common to all three label sets
    shared = dict(
        y=y_col_name, source=source, level='glyph',
        text_align='center', text_color='gray', y_offset=-2,
    )

    # Category names at the origin, on a semi-transparent white background
    center_labels = LabelSet(
        x=0, text=y_col_name, x_offset=0,
        text_font_size='7pt', text_alpha=0.9,
        background_fill_color='white', background_fill_alpha=0.6,
        **shared
    )

    # Value labels at the bar ends; the left one is positioned by the negated
    # column but shows the original value, and is nudged further left
    left_labels = LabelSet(
        x=l_col_val, text=l_col_val + ABS_COL_POSTFIX,
        x_offset=-10, text_font_size='8pt',
        **shared
    )
    right_labels = LabelSet(
        x=r_col_val, text=r_col_val,
        x_offset=10, text_font_size='8pt',
        **shared
    )

    for label_set in (center_labels, left_labels, right_labels):
        p.add_layout(label_set)

def _helper_create_pivot_df(
        df:pd.DataFrame,
        l_col_val, r_col_val,
        x_col_name, y_col_name,
        group_col_name
        )->pd.DataFrame:
    '''
    Reshapes the long-format df into one row per y_col_name value with one
    column per group value, and prepares the left column for plotting.

    The returned frame has:
    - y_col_name as a regular column,
    - l_col_val holding the negated left values (so bars go to the left),
    - l_col_val + ABS_COL_POSTFIX holding the original left values (labels),
    - r_col_val holding the right values unchanged.
    A side that is absent from the data is added as a column of zeros.
    '''
    wide = df.pivot(index=y_col_name, columns=group_col_name, values=x_col_name)
    wide = wide.reset_index()

    # Make sure both sides exist even if the data contains only one group
    for side in (l_col_val, r_col_val):
        if side not in wide:
            wide[side] = 0

    left_values = wide[l_col_val]
    # Keep the original values for labelling before flipping the sign
    wide[l_col_val + ABS_COL_POSTFIX] = left_values
    wide[l_col_val] = left_values * -1
    return wide

def _helper_set_axes(p):
    '''
    Styles the axes of figure p: x ticks are shown as absolute values, the
    y-axis is hidden and both grids have their lines removed.
    '''
    # Left-side values are negative in the data; show them without the sign
    abs_formatter = CustomJSTickFormatter(code="return Math.abs(tick);")
    p.xaxis.formatter = abs_formatter

    p.yaxis.visible = False

    for grid in (p.xgrid, p.ygrid):
        grid.grid_line_color = None