Interactive tree/hierarchy diagram

Hello all,

I am looking to create a plot that looks like so:

I have come across separate ideas for this sort of thing, but they usually show “exploded” network graphs where the lines and nodes aren’t as neatly laid out.

Here is a good description of how to get positions for a tree diagram using networkx:
Can one get hierarchical graphs from networkx with python 3? - Stack Overflow

import networkx as nx
import random

    
def hierarchy_pos(G, root=None, width=1., vert_gap = 0.2, vert_loc = 0, xcenter = 0.5):

    '''
    From Joel's answer at https://stackoverflow.com/a/29597209/2966723.  
    Licensed under Creative Commons Attribution-Share Alike 
    
    If the graph is a tree this will return the positions to plot this in a 
    hierarchical layout.
    
    G: the graph (must be a tree)
    
    root: the root node of current branch 
    - if the tree is directed and this is not given, 
      the root will be found and used
    - if the tree is directed and this is given, then 
      the positions will be just for the descendants of this node.
    - if the tree is undirected and not given, 
      then a random choice will be used.
    
    width: horizontal space allocated for this branch - avoids overlap with other branches
    
    vert_gap: gap between levels of hierarchy
    
    vert_loc: vertical location of root
    
    xcenter: horizontal location of root
    '''
    if not nx.is_tree(G):
        raise TypeError('cannot use hierarchy_pos on a graph that is not a tree')

    if root is None:
        if isinstance(G, nx.DiGraph):
            root = next(iter(nx.topological_sort(G)))  #allows back compatibility with nx version 1.11
        else:
            root = random.choice(list(G.nodes))

    def _hierarchy_pos(G, root, width=1., vert_gap = 0.2, vert_loc = 0, xcenter = 0.5, pos = None, parent = None):
        '''
        see hierarchy_pos docstring for most arguments

        pos: a dict saying where all nodes go if they have been assigned
        parent: parent of this branch. - only affects it if non-directed

        '''
    
        if pos is None:
            pos = {root:(xcenter,vert_loc)}
        else:
            pos[root] = (xcenter, vert_loc)
        children = list(G.neighbors(root))
        if not isinstance(G, nx.DiGraph) and parent is not None:
            children.remove(parent)  
        if len(children)!=0:
            dx = width/len(children) 
            nextx = xcenter - width/2 - dx/2
            for child in children:
                nextx += dx
                pos = _hierarchy_pos(G,child, width = dx, vert_gap = vert_gap, 
                                    vert_loc = vert_loc-vert_gap, xcenter=nextx,
                                    pos=pos, parent = root)
        return pos

            
    return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter)

How can this be used with Bokeh to create an interactive version of the plot? For example, I need to be able to click on a node and trigger a click event, and to hover over a node to see extra information about it.

Here is a somewhat related graph made with Bokeh, but it doesn’t look like a hierarchy tree:
Make an Interactive Network Visualization with Bokeh — Introduction to Cultural Analytics & Python (melaniewalsh.github.io)

Appreciate any suggestions.

TBH I am not sure this is a good fit for Bokeh. It’s certainly not impossible, but you’re going to have to lay out everything manually. Additionally, support for “vectorized” arrows (or lines with arrowheads) has not been added yet, so you would need to add a separate annotation for every arrow.
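To make the "one annotation per arrow" point concrete, here is a rough sketch of the geometry only, so it runs without Bokeh. It assumes a top-down layout and node boxes whose half-height is known in data units (0.05 below is an arbitrary value); `arrow_segments` is a hypothetical helper. Each returned tuple would then need its own annotation, e.g. `plot.add_layout(Arrow(end=VeeHead(size=8), x_start=x0, y_start=y0, x_end=x1, y_end=y1))` with `bokeh.models.Arrow` and `VeeHead`.

```python
def arrow_segments(edges, pos, box_half_height):
    """Return (x_start, y_start, x_end, y_end) for one arrow per parent->child edge.

    Assumes children are drawn below their parents and that every node box has
    a half-height of `box_half_height` in data units, so each arrow runs from
    the bottom of the parent's box to the top of the child's box.
    """
    segments = []
    for parent, child in edges:
        (px, py), (cx, cy) = pos[parent], pos[child]
        segments.append((px, py - box_half_height, cx, cy + box_half_height))
    return segments


# Hypothetical positions for a root with two children.
pos = {"President": (0.5, 0.0), "Supte1": (0.25, -0.2), "Supte2": (0.75, -0.2)}
edges = [("President", "Supte1"), ("President", "Supte2")]

for segment in arrow_segments(edges, pos, box_half_height=0.05):
    print(segment)
```

Because every arrow is a separate annotation object, the number of models grows with the number of edges, which is the cost the reply above is pointing at.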


I’ve looked at cooking up something like this with bokeh and came to the same conclusion… but I didn’t wanna say it couldn’t be done :smiley:

@CTPassion Here is one suggestion. I use a networkx DiGraph to create a directed graph from a dataframe in which I have defined the parent and child nodes. I use your function hierarchy_pos to get the positions (one can also use nx.nx_agraph.pygraphviz_layout, but it requires some extra imports and Graphviz to be installed).
I then use the edges and positions to calculate new edges and nodes that will give vertical and horizontal lines (can probably be optimized).
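One possible simplification of that step, offered as an untested alternative rather than what the script below does: instead of adding intermediate nodes to a new graph, build one right-angled polyline per edge and draw all of them with a single `plot.multi_line(xs, ys)` call. `elbow_lines` and the sample positions are hypothetical.

```python
def elbow_lines(edges, pos):
    """Build MultiLine-style xs/ys lists with one right-angled path per edge.

    Each parent->child edge becomes a single polyline:
    parent -> (parent_x, mid_y) -> (child_x, mid_y) -> child,
    where mid_y is halfway between the two levels.
    """
    xs, ys = [], []
    for parent, child in edges:
        (px, py), (cx, cy) = pos[parent], pos[child]
        mid_y = (py + cy) / 2
        xs.append([px, px, cx, cx])
        ys.append([py, mid_y, mid_y, cy])
    return xs, ys


# Hypothetical positions, e.g. as returned by hierarchy_pos.
pos = {0: (0.5, 0.0), 1: (0.25, -0.2), 2: (0.75, -0.2)}
xs, ys = elbow_lines([(0, 1), (0, 2)], pos)
```

The trade-off is that the lines are then a plain glyph rather than the edge renderer of a graph, so graph-specific features such as edge inspection policies would not apply to them.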

I have added the arrows using inverted_triangle markers. Positioning them took a bit of calculation, since the heights of the boxes are in screen units while the node positions are in data units. It also means that the arrows will not stay at the correct location if you zoom.
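The conversion behind that calculation is just "data units per screen pixel". A minimal sketch, assuming a y range of -0.44 to 0.04 drawn over a 490 px inner height (values read off the hard-coded `px_resol` constant in the script below, so they only hold for that particular figure); `arrow_y` is a hypothetical helper:

```python
def arrow_y(node_y, y_start, y_end, inner_height_px, offset_px=20):
    """Return the data-space y of a marker drawn `offset_px` pixels above node_y.

    (y_end - y_start) / inner_height_px is the number of data units per screen
    pixel, i.e. the quantity the script calls `px_resol`.
    """
    data_per_px = (y_end - y_start) / inner_height_px
    return node_y + offset_px * data_per_px


# Assumed values: y range -0.44..0.04 over a 490 px inner height.
before_zoom = arrow_y(-0.2, -0.44, 0.04, 490)
# Zooming in 2x halves the visible y range, so the same 20 px offset
# corresponds to half as many data units.
after_zoom = arrow_y(-0.2, -0.32, -0.08, 490)
```

This is why a value computed once goes stale after zooming: the pixel offset is fixed, but its size in data units changes with the range.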

I have added a TapTool with a JavaScript (CustomJS) click callback.

Edit: add the following to have the arrows stay at the correct location when zooming with the mouse wheel:

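# Add after `plot` and `src` are created (and before save(plot)) in the script below.
from bokeh.events import MouseWheel
from bokeh.models import CustomJS

cb = CustomJS(
    args={'source': src, 'plot': plot, 'yrng': plot.y_range},
    code='''
    // data units per screen pixel for the current y range
    const canvas_dy = plot.inner_height;
    const rng_y0 = yrng.start;
    const rng_y1 = yrng.end;

    const y_offset = 20;  // arrow offset above the box centre, in pixels
    const px_resol = (rng_y1 - rng_y0)/canvas_dy;
    const data = source.data;
    // start at 1: row 0 is the root node, which has no arrow (NaN)
    for (let i = 1; i < data['arrow_y'].length; i++) {
        data['arrow_y'][i] = data['y'][i] + y_offset*px_resol;
    }

    source.data = data;
    source.change.emit();
    ''')
plot.js_on_event(MouseWheel, cb)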

import random  # used by hierarchy_pos for undirected graphs

import pandas as pd
import numpy as np
import networkx as nx
from bokeh.io import save, output_file
from bokeh.models import Circle, ColumnDataSource, MultiLine
from bokeh.models import HoverTool, TapTool, CustomJS
from bokeh.plotting import figure, from_networkx

output_file("plot_hierarchy.html")

parent_childs = {
    'President': ['Supte1', 'Supte2', 'Supte3'],
    'Supte1': ['Ger1', 'Ger2']
}
levels = ['President', 'Supte1']
tree_df = pd.DataFrame(
    [(l, c) for l in levels for c in parent_childs[l]],
    columns = ['parent', 'child']
    )

def hierarchy_pos(G, root=None, width=1., vert_gap = 0.2, vert_loc = 0, xcenter = 0.5):

    '''
    From Joel's answer at https://stackoverflow.com/a/29597209/2966723.  
    Licensed under Creative Commons Attribution-Share Alike 
    
    If the graph is a tree this will return the positions to plot this in a 
    hierarchical layout.
    
    G: the graph (must be a tree)
    
    root: the root node of current branch 
    - if the tree is directed and this is not given, 
      the root will be found and used
    - if the tree is directed and this is given, then 
      the positions will be just for the descendants of this node.
    - if the tree is undirected and not given, 
      then a random choice will be used.
    
    width: horizontal space allocated for this branch - avoids overlap with other branches
    
    vert_gap: gap between levels of hierarchy
    
    vert_loc: vertical location of root
    
    xcenter: horizontal location of root
    '''
    if not nx.is_tree(G):
        raise TypeError('cannot use hierarchy_pos on a graph that is not a tree')

    if root is None:
        if isinstance(G, nx.DiGraph):
            root = next(iter(nx.topological_sort(G)))  #allows back compatibility with nx version 1.11
        else:
            root = random.choice(list(G.nodes))

    def _hierarchy_pos(G, root, width=1., vert_gap = 0.2, vert_loc = 0, xcenter = 0.5, pos = None, parent = None):
        '''
        see hierarchy_pos docstring for most arguments

        pos: a dict saying where all nodes go if they have been assigned
        parent: parent of this branch. - only affects it if non-directed

        '''
    
        if pos is None:
            pos = {root:(xcenter,vert_loc)}
        else:
            pos[root] = (xcenter, vert_loc)
        children = list(G.neighbors(root))
        if not isinstance(G, nx.DiGraph) and parent is not None:
            children.remove(parent)  
        if len(children)!=0:
            dx = width/len(children) 
            nextx = xcenter - width/2 - dx/2
            for child in children:
                nextx += dx
                pos = _hierarchy_pos(G,child, width = dx, vert_gap = vert_gap, 
                                    vert_loc = vert_loc-vert_gap, xcenter=nextx,
                                    pos=pos, parent = root)
        return pos

            
    return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter)

def horiz_vert_edges(G, pos):
    '''
    Calculate vertical and horizontal positions and create new edges
    '''
    new_pos = {}
    new_edges = []

    i = max(pos.keys())
    for (n1, n2) in G.edges:
        if pos[n1][0] == pos[n2][0]:
            new_pos[n1] = pos[n1]
            new_pos[n2] = pos[n2]
            new_edges.append((n1, n2))
            continue

        i += 1
        y = (pos[n1][1]+pos[n2][1])/2

        new_edges.append((n1, i))
        new_pos[n1] = pos[n1]
        new_pos[i] = (pos[n1][0], y)

        j = i + 1
        new_edges.append((i, j))
        new_pos[j] = (pos[n2][0], y)
        new_pos[n2] = pos[n2]

        new_edges.append((j, n2))

        i += 1

    return new_pos, new_edges

# Generate directed graph
G = nx.from_pandas_edgelist(
    tree_df , 'parent', 'child', create_using = nx.DiGraph
)

#roots = [n for n,d in G.in_degree() if d==0]
#end_nodes = [x for x in G.nodes() if G.out_degree(x)==0 and G.in_degree(x)==1]

#print(G.edges)
#print(roots)
#print(end_nodes)

# Bokeh requires integer labels
G = nx.convert_node_labels_to_integers(G, label_attribute = 'node')
pos = hierarchy_pos(G)

# create new edges and positions for vertical and horizontal layout
new_pos, new_edges = horiz_vert_edges(G, pos)

# for drawing the layout, use the new edges and positions
GN = nx.DiGraph()
GN.add_edges_from(new_edges)

# create CDS for labels
df_cds = pd.DataFrame.from_dict(
    pos, orient = 'index', columns = ['x','y']
).reset_index()


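# data units per screen pixel, read off once from the CustomJS console output
# below (y range of about -0.44 to 0.04 over a 490 px inner height).
# Needed to position the arrows at the top of the boxes; somewhat cumbersome.
px_resol = (0.04+0.44)/490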

node_int_mapping = {n: G.nodes[n]["node"] for n in pos}
df_cds['node'] = df_cds['index'].map(node_int_mapping)
df_cds['arrow_y'] = df_cds['y'] + 20*px_resol
df_cds.loc[df_cds['node'] == 'President', ['arrow_y']] = np.nan

src = ColumnDataSource(df_cds)

plot = figure(
    width = 800,
    height = 500,
    tools = "pan,wheel_zoom,box_zoom,save,reset",
    active_scroll = 'wheel_zoom'
)

plot.x_range.range_padding = 0.2
plot.y_range.range_padding = plot.x_range.range_padding

# create Bokeh network graph 
network_graph = from_networkx(GN, new_pos)

# do not show automatic circles (size = 0)
network_graph.node_renderer.glyph = Circle(size=0, fill_color='skyblue')

#Set edge opacity and width
network_graph.edge_renderer.glyph = MultiLine(line_alpha=1, line_width=2)

#Add network graph to the plot
plot.renderers.append(network_graph)

# One issue with mixing screen and data units is that zooming
# ruins the layout.
# I have kept the boxes in screen units since the text is also in screen units,
# but that means the arrows are not positioned correctly when zooming.
r_rect = plot.rect(
    x = 'x', 
    y = 'y', 
    width = 70,
    width_units = 'screen',
    height = 30,
    height_units = 'screen',
    fill_color = 'white',
    line_color = 'navy',
    line_width = 2,
    border_radius = 5,
    source = src
)

plot.text(
    x = 'x', 
    y = 'y', 
    text_baseline = 'middle',
    text_align = 'center',
    text_font_size = '9pt',
    text = 'node',
    source = src
)

plot.inverted_triangle(
    x = 'x',
    y = 'arrow_y',
    color = 'black',
    size = 8,
    source = src
    )
hover = HoverTool(tooltips=[('Node', '@node')], renderers = [r_rect])
plot.add_tools(hover)

code = '''
const idx = source.inspected.indices;
console.log('Item: ' + source.data['node'][idx]);

const canvas_dy = plot.inner_height;
console.log(canvas_dy);

const rng_y0 = yrng.start;
const rng_y1 = yrng.end;
console.log(rng_y0);
console.log(rng_y1);
'''
callback = CustomJS(args = {'source': src, 'plot': plot, 'yrng': plot.y_range}, code=code)
plot.add_tools(
    TapTool(
        callback = callback,
        behavior ='inspect',
        renderers = [r_rect]))


plot.toolbar.autohide = True
plot.axis.visible = False
plot.grid.visible = False
plot.outline_line_color = None

save(plot)

Thanks for the responses. I had also found that networkx method on Stack Overflow and made my own solution, incorporating some of yours to fit my situation. I did not need the arrows, so I left them out since, as you said, they are buggy anyway.

I also added a block of code to remove the ‘0’ node if it exists. I wanted my tree to start with multiple nodes at the top, but the hierarchy_pos function from you/Stack Overflow only works with a single starting node (otherwise the graph is not a tree). To get around this, I keep the 0 node as the root until the positions are calculated and then remove it. I also make sure the children of node 0 sit at the same y level as node 0, so the graph is not shifted down by the extra level :slight_smile:
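The virtual-root trick described above can be sketched in plain Python without networkx. This is a minimal sketch, assuming the edges are (parent, child) pairs and that any parent never listed as a child is a top-level node; `layout_forest` and the `VIRTUAL_ROOT` sentinel are hypothetical (the post uses node `0`).

```python
from collections import defaultdict

VIRTUAL_ROOT = "__virtual_root__"  # hypothetical sentinel; the post uses node 0


def layout_forest(edges, width=1.0, vert_gap=0.2):
    """Lay out a forest (several top-level nodes) with a one-root tree layout.

    A virtual root is attached above every top-level node, its children are
    kept on the same level as the root (so nothing is shifted down), and the
    virtual root is dropped from the result afterwards.
    """
    children = defaultdict(list)
    for parent, child in edges:
        children[parent].append(child)
    child_set = {c for _, c in edges}
    # Top-level nodes: parents that never appear as a child.
    children[VIRTUAL_ROOT] = [p for p in children if p not in child_set]

    pos = {}

    def place(node, width, x, y):
        pos[node] = (x, y)
        kids = children.get(node, [])
        if not kids:
            return
        dx = width / len(kids)
        # Children of the virtual root stay on the root's level.
        next_y = y if node == VIRTUAL_ROOT else y - vert_gap
        for k, kid in enumerate(kids):
            place(kid, dx, x - width / 2 + dx * (k + 0.5), next_y)

    place(VIRTUAL_ROOT, width, 0.0, 0.0)
    del pos[VIRTUAL_ROOT]  # remove the helper root once positions exist
    return pos


pos = layout_forest([("A", "A1"), ("A", "A2"), ("B", "B1")])
```

Here "A" and "B" both end up at y = 0 and their children at y = -0.2, matching the behaviour described in the post.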

    @staticmethod
    def hierarchy_pos(
        graph, root=None, width=1.0, vert_gap=0.2, vert_loc=0, x_centre=0.0
    ) -> dict:
        if not nx.is_tree(graph):
            raise TypeError("Cannot use hierarchy_pos on a graph that is not a tree")

        if root is None:
            root = next(iter(nx.topological_sort(graph)))

        def _hierarchy_pos(
            graph,
            root,
            width=1.0,
            vert_gap=0.2,
            vert_loc=0.0,
            x_centre=0.5,
            pos=None,
            parent=None,
        ) -> dict:
            if pos is None:
                pos = {root: (x_centre, vert_loc)}
            else:
                pos[root] = (x_centre, vert_loc)

            children = list(graph.neighbors(root))

            if not isinstance(graph, nx.DiGraph) and parent is not None:
                children.remove(parent)

            if children:
                child_width = width / len(children)
                next_x = x_centre - width / 2 - child_width / 2
                # Children of the virtual root 0 stay on the root's level
                next_y = vert_loc - vert_gap if root != 0 else vert_loc
                for child in children:
                    next_x += child_width
                    pos = _hierarchy_pos(
                        graph,
                        child,
                        width=child_width,
                        vert_gap=vert_gap,
                        vert_loc=next_y,
                        x_centre=next_x,
                        pos=pos,
                        parent=root,
                    )
            return pos

        return _hierarchy_pos(graph, root, width, vert_gap, vert_loc, x_centre)

        # Create graph
        graph = nx.from_pandas_edgelist(
            dataframe,
            "PARENT_NODE_NUMBER",
            "NODE_NUMBER",
            create_using=nx.DiGraph(),
        )

        # Collect data into nodes
        node_properties = self.determine_node_properties(selected_df)

        positions = self.hierarchy_pos(graph)

        new_pos, new_edges = self.horizontal_vertical_edges(graph, positions)
        # For drawing right-angled lines use new edges and positions
        new_graph = nx.DiGraph()
        new_graph.add_edges_from(new_edges)

        # The first node may be a parent node that is not a tag: nx adds it as
        # the root because the top-level tags have '0' as their parent. Remove it:
        if 0 in new_pos:
            del new_pos[0]
            new_graph.remove_node(0)
            graph.remove_node(0)
            del positions[0]

        network_graph = from_networkx(new_graph, new_pos)
        network_graph.inspection_policy = NodesOnly()

        # Style nodes
        network_graph.node_renderer.glyph = Rect(
            height=0.12,
            width=0.12,
            fill_color="fill_colour",
            border_radius=6,
            line_color="line_colour",
            line_width="line_width",
            dilate=True,
        )
        network_graph.node_renderer.selection_glyph = Rect(
            height=0.14,
            width=0.14,
            fill_color="fill_colour",
            border_radius=6,
            line_color="line_colour",
            line_width="selected_line_width",
            dilate=True,
        )
        network_graph.node_renderer.hover_glyph = Rect(
            height=0.12,
            width=0.12,
            fill_color="fill_colour",
            border_radius=6,
            line_color="line_colour",
            line_width="hover_line_width",
            dilate=True,
        )
        network_graph.node_renderer.data_source.data = {
            "index": list(graph.nodes),
            "name": [node_properties["node_names"][node] for node in graph.nodes],
            "value": [node_properties["node_values"][node] for node in graph.nodes],
            "label": [node_properties["node_labels"][node] for node in graph.nodes],
            "fill_colour": [
                node_properties["node_colours"][node] for node in graph.nodes
            ],
            "line_colour": [
                node_properties["line_colours"][node] for node in graph.nodes
            ],
            "line_width": [
                node_properties["line_widths"][node] for node in graph.nodes
            ],
            "hover_line_width": [
                node_properties["hover_line_widths"][node] for node in graph.nodes
            ],
            "selected_line_width": [
                node_properties["selected_line_widths"][node] for node in graph.nodes
            ],
        }

        # Style connections
        network_graph.edge_renderer.glyph = MultiLine(line_width=2, line_color="black")
        plot.renderers.append(network_graph)

        # Add text labels
        df_cds = pd.DataFrame.from_dict(
            positions, orient="index", columns=["x", "y"]
        ).reset_index()
        df_cds["label"] = df_cds["index"].map(node_properties["node_labels"])
        text_src = ColumnDataSource(df_cds)
        text_glyph = Text(
            x="x",
            y="y",
            text="label",
            text_font_size="12pt",
            text_baseline="middle",
            text_align="center",
            text_color="white",
        )
        plot.add_glyph(text_src, text_glyph)

        # Hover
        hover = HoverTool(
            tooltips=f"""<div>
                            <span style="font-family: {styles.FONT_FAMILY};font-size: 18px;"><b>Name:</b> @name</span><br>
                            <span style="font-family: {styles.FONT_FAMILY};font-size: 18px;"><b>Value:</b> @value</span>
                         </div>""",
            muted_policy="ignore",
            point_policy="follow_mouse",
            renderers=[network_graph.node_renderer],
            attachment="left",
        )
        plot.add_tools(hover, TapTool())
        plot.on_event(Tap, self.plot_callback)

        return plot

With a 0 node:

With multiple “parent” nodes:


@CTPassion Glad to see you found a solution, and thanks for sharing. I learned a bit about how to properly update network_graph properties :slightly_smiling_face:

I would not say that the arrows are buggy; it was just a hacky way to add arrows, because in my example I mix screen units (the width/height of the rectangles) with data units (the node locations). I wanted to use screen units for the rectangles since the font is also in screen units. But whether zooming is even necessary on a tree diagram is another story :slight_smile:
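If the boxes are drawn in data units instead (as the Rect-based version above does), they scale together with the positions when zooming, and the remaining problem is picking a width that cannot overlap. A minimal sketch, assuming a `pos` dict like the one hierarchy_pos returns; `max_box_width` and the 0.8 fill factor are hypothetical choices, not Bokeh API. The result could be used as the (data-unit) `width` of a `Rect` glyph.

```python
from collections import defaultdict


def max_box_width(pos, fill=0.8, ndigits=9):
    """Widest box (in data units) that keeps boxes on the same level apart.

    Nodes are grouped by (rounded) y coordinate, the smallest horizontal gap
    between neighbouring nodes on any level is found, and `fill` times that
    gap is returned. Returns None if no level has two or more nodes.
    """
    levels = defaultdict(list)
    for x, y in pos.values():
        levels[round(y, ndigits)].append(x)

    smallest_gap = None
    for xs in levels.values():
        xs.sort()
        for left, right in zip(xs, xs[1:]):
            gap = right - left
            if smallest_gap is None or gap < smallest_gap:
                smallest_gap = gap
    return None if smallest_gap is None else fill * smallest_gap


# Hypothetical positions shaped like hierarchy_pos output for a small tree.
pos = {0: (0.5, 0.0), 1: (0.25, -0.2), 2: (0.75, -0.2),
       3: (0.125, -0.4), 4: (0.375, -0.4)}
box_width = max_box_width(pos)
```

Rounding the y keys guards against tiny floating-point differences between nodes that should sit on the same level.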


FWIW, a proper vectorized glyph for arrows is hopefully on the horizon (more generally, all existing annotation classes are currently being moved to become standard glyphs, so that “annotation” is viewed just as something you do, rather than some set of classes).


Separate point: have you figured out how to size the node glyphs so that they don’t overlap when the tree becomes large?
For example, repeating the tree makes them overlap.

EDIT
I have answered my own question. Please use this version to calculate positions for a more general approach, or if you are getting overlaps. The key change is to make the horizontal space allocated to each child proportional to the size of its subtree (the child plus all of its descendants). This way, nodes with more descendants get more space, which reduces the chance of overlap.

    @staticmethod
    def hierarchy_pos(
            graph, root=None, width=1.0, vert_gap=0.2, vert_loc=0, x_centre=0.0, x_offset=0.5
    ) -> dict:
        """
        Return positions for a hierarchical layout of the input tree-like graph.

        Parameters
        ----------
            graph   : networkx graph
                The input graph (must be a tree)
            root    : int
                The root node of the current branch
            width   : float
                Horizontal space allocated for this branch - avoids overlap with other branches
            vert_gap    : float
                Gap between levels of hierarchy
            vert_loc    : float
                Vertical location of root
            x_centre     : float
                Horizontal location of root
            x_offset    : float
                Fraction of width placed left of x_centre where the children's slots start (0.5 centres them under the parent)

        Returns
        -------
            dict
                A dictionary containing positions of nodes in the hierarchical layout
        """
        if not nx.is_tree(graph):
            raise TypeError("Cannot use hierarchy_pos on a graph that is not a tree")

        if root is None:
            root = next(iter(nx.topological_sort(graph)))

        def _hierarchy_pos(
                graph,
                root,
                width=width,
                vert_gap=vert_gap,
                vert_loc=vert_loc,
                x_centre=x_centre,
                pos=None,
                parent=None,
                x_offset=x_offset
        ) -> dict:
            if pos is None:
                pos = {root: (x_centre, vert_loc)}
            else:
                pos[root] = (x_centre, vert_loc)

            children = list(graph.neighbors(root))
            if not isinstance(graph, nx.DiGraph) and parent is not None:
                children.remove(parent)  # remove the parent from the children list if it's there

            # Size of each child's subtree (the child plus its descendants), used for dynamic width allocation
            subtree_sizes = {child: len(nx.descendants(graph, child)) + 1 for child in children}
            total_descendants = sum(subtree_sizes.values())

            if total_descendants > 0:
                relative_widths = {child: size / total_descendants for child, size in subtree_sizes.items()}
            else:
                relative_widths = {child: 1 for child in children}

            # Keep track of the x position for the next child
            next_x = x_centre - (width * x_offset)

            for child in children:
                child_width = width * relative_widths[child]
                # Centre the child in its allocated space
                child_x_centre = next_x + (child_width / 2)
                next_x += child_width  # Increment x position for the next child

                # Recursively call to place each child's subtree
                pos = _hierarchy_pos(
                    graph,
                    child,
                    width=child_width,
                    vert_gap=vert_gap,
                    vert_loc=vert_loc - vert_gap,
                    x_centre=child_x_centre,
                    pos=pos,
                    parent=root,
                    x_offset=x_offset  # Pass the offset for recursive calls
                )

            return pos

        return _hierarchy_pos(graph, root, width, vert_gap, vert_loc, x_centre)


and another repeat: