Set intersection of selections from two or more lasso tools

Hi all - I have been learning Bokeh (and JavaScript, to boot!) for a few projects in my lab.

My goal for this figure is to draw two or more scatter plots, each of which has a subset of “IDs” that are common to every plot. I want to be able to:

  • Lasso clouds of points in every plot.
  • Do a set intersection on the IDs of the selected points.
  • Plot the same data in another figure below, but with the intersection points highlighted.

I want to use JavaScript callbacks so that I can run this natively in a browser without a Bokeh server.

Below is a minimal working example of what I have done so far. I tried to implement this by making a merged data source and using a JavaScript callback that, whenever the lasso selection changes, looks at all the selected points in the independent plots and counts how many times each ID occurs. If an ID occurs as many times as there are plots, the callback pushes it as a selected index of the merged data source (and shouldn’t that highlight rows in the table?). I chose this implementation because set intersection apparently isn’t natively supported in JavaScript in all browsers (Set.prototype.intersection() - JavaScript | MDN), and I know that each ID occurs only once in each plot.

All the plots are there, and I added a table whose rows I think should be highlighted once the corresponding values are selected. I have not yet written any code to highlight the dots in the lower plots based on what is highlighted in the table.

# First attempt (as originally posted): lasso points in each top plot and, in the
# merged source, select the IDs that are selected in every top plot.
# NOTE(review): this version does not work as intended -- see the notes on the
# per-plot source's `ix` key and on the CustomJS callback below.
import pandas as pd
import numpy as np
from bokeh.plotting import figure, show
from bokeh.models   import ColumnDataSource, DataTable, TableColumn, CustomJS
from bokeh.layouts  import row, column


# Two toy dataframes indexed by ID; only some IDs appear in both.
dfs = []
data1 = {"ID": ['ID1', 'ID2', 'ID3', 'ID5', 'ID6', 'ID7',    'ID9',     'ID10'],
         "1_x": np.random.randint(10, 20, 8),
         "1_y": np.random.randint(10, 20, 8)}
df1 = pd.DataFrame(data1, index=data1['ID']).drop(columns=['ID'])
dfs.append(df1)

data2 = {"ID": ['ID1',     'ID3', 'ID4', 'ID5', 'ID6', 'ID7', 'ID8', 'ID10'],
         "2_x": np.random.randint(10, 20, 8),
         "2_y": np.random.randint(10, 20, 8)}
df2 = pd.DataFrame(data2, index=data2['ID']).drop(columns=['ID'])
dfs.append(df2)

# Outer join on the ID index: an ID missing from one dataframe gets NaN in that
# dataframe's columns.
dfm = pd.merge(df1, df2, left_index=True, right_index=True, how='outer')
# yields something like this:
#        1_x   1_y   2_x   2_y
# ID1   19.0  10.0  18.0  10.0
# ID10  14.0  14.0  11.0  15.0
# ID2   17.0  12.0   NaN   NaN
# ID3   19.0  18.0  11.0  11.0
# ID4    NaN   NaN  11.0  19.0
# ID5   14.0  10.0  13.0  10.0
# ID6   10.0  15.0  14.0  17.0
# ID7   15.0  14.0  11.0  19.0
# ID8    NaN   NaN  16.0  15.0
# ID9   11.0  10.0   NaN   NaN
# Merged source: one column per dataframe column plus an "ID" column; it backs
# both the DataTable and the bottom row of figures.
merged_dict = {col: dfm[col].values for col in dfm.columns}
merged_dict['ID'] = dfm.index
merge_source = ColumnDataSource(data=merged_dict)
columns = [TableColumn(field=x, title=x) for x in dfm.columns] + [TableColumn(field='ID', title='ID')]
data_table = DataTable(source=merge_source, columns = columns, editable=True, width=400, height = 400)

figures      = [] # these figures are for the top - they have their own, independent data sources
figures2     = [] # these figures are for the bottom and draw from the merged data source
bokeh_frames = []
for i in range(len(dfs)):
    # Build one top figure (own source, lasso enabled) and one bottom figure
    # (merged source) per dataframe. Nothing is saved to a file here.
    figures.append(
        figure(width  = 200,
               height = 200,
               tools  = ["lasso_select", "reset"],
               title  = f"Plot {i}"))
    figures2.append(
        figure(width  = 200,
               height = 200,
               tools = [],
               title  = f"Plot {i} - intersection"))
    # NOTE(review): the IDs are stored under the key "ix", but the callback
    # below reads data["ID"] from these sources, so that lookup finds no column
    # (presumably undefined in JS, which makes the indexing after it fail).
    # The follow-up version renames this key to ID.
    bokeh_frames.append(
        ColumnDataSource(data=dict(ix = list(dfs[i].index)   ,
                                   x  = list(dfs[i][f"{i+1}_x"]),
                                   y  = list(dfs[i][f"{i+1}_y"]),
                                   )))

    figures[-1 ].circle('x', 'y',
                       size=8, source = bokeh_frames[-1], selection_color="firebrick")
    figures2[-1].circle(f"{i+1}_x", f"{i+1}_y",
                       size=8, source = merge_source,     selection_color="firebrick")

# Attach a CustomJS callback to every top source's selection: count how many
# top plots each selected ID is selected in, and select in the merged source the
# IDs selected in all of them.
# NOTE: thisdict and javascript_array_string come out identical on every
# iteration; only the source the callback is attached to differs.
for i in range(len(bokeh_frames)):
    # Expose every top source to JS as s1, s2, ... plus the merged source.
    thisdict = {f"s{j+1}": bokeh_frames[j] for j in range(len(bokeh_frames))}
    thisdict["sMerge"] = merge_source
    thisdict["num_plots"] = len(bokeh_frames)
    # Build the JS array literal "[s1, s2, ...]" that is spliced into the code.
    javascript_array_string = "["
    for j in range(len(bokeh_frames)-1):
        javascript_array_string += f"s{j+1}, "
    javascript_array_string += f"s{len(bokeh_frames)}]"
    # NOTE(review): inside the JS below, `rowix` is assigned without var/let
    # (an implicit global; it would throw a ReferenceError if the code runs in
    # strict mode). Also, indices are push()ed into sMerge.selected.indices in
    # place and then sMerge.change is emitted -- in-place mutation likely does
    # not notify BokehJS that `indices` changed, so the table may not refresh.
    # Assigning a new array is the usual pattern -- confirm.
    bokeh_frames[i].selected.js_on_change(
        "indices",
        CustomJS(
            args=thisdict,
            code=f"""
                // clear out previously selected values
                sMerge.selected.indices = []
                // set up counts for the indices
                let counts = {{}};
                // store the selected values of each plot.
                // This uses a funny combination of python to drop a variable number of objects into the javascript array code
                var selected_lists = {javascript_array_string}
                for (var i = 0; i < selected_lists.length; i++) {{
                    // for each table, go through the selected indices, get the index value, and increment the count
                    for (var j = 0; j < selected_lists[i].selected.indices.length; j++) {{
                        rowix = selected_lists[i].selected.indices[j]
                        // get the value in the 'ID' column of selected_lists[i]
                        var ID = selected_lists[i].data["ID"][rowix]
                        if (ID in counts) {{
                            counts[ID] += 1
                        }} else {{
                            counts[ID] = 1
                        }}
                    }}
                }}
                // Go through counts, if the value matches num_plots, push to table index
                for (var key in counts) {{
                    if (counts[key] == num_plots) {{
                        sMerge.selected.indices.push(sMerge.data["ID"].indexOf(key))
                    }}
                }}
                // update the DataTable
                sMerge.change.emit()
            """
        )
    )


# Layout: top plots, then the table, then the bottom "intersection" plots.
show(column(row(figures), data_table, row(figures2)))

It turns out I was very close to getting it right! I made some changes, including renaming ix to ID in the bokeh_frames.append(...) call, and a few minor changes to the JavaScript.

I looked closely at this code: Bokeh hover on two tables with different data sources

and I found that I must update the DataTable object itself (data_table in my code above), not the ColumnDataSource (merge_source in my code above), to get the behavior I wanted. I don’t completely understand this yet: do I actually need both objects (merge_source and data_table) in my JavaScript callback to do what I want? It seems like I’m essentially loading the same data twice when I could just access everything from DataTable.data.

Here is the minimal working example, which now works. Use the lasso to select points in both of the top plots.

# from here: https://discourse.bokeh.org/t/bokeh-hover-on-two-tables-with-different-data-sources/8202
#
# Lasso points in each of the top scatter plots. The IDs that are selected in
# *every* top plot (the set intersection) become the selection of the merged
# data source. That highlights the matching DataTable rows and the matching
# points in the bottom "intersection" plots. Everything runs client-side via
# CustomJS, so the page works without a Bokeh server.
import pandas as pd
import numpy as np
from bokeh.plotting import figure, show
from bokeh.models   import ColumnDataSource, DataTable, TableColumn, CustomJS
from bokeh.layouts  import row, column


# Two toy dataframes indexed by ID; only some IDs appear in both.
dfs = []
data1 = {"ID": ['ID1', 'ID2', 'ID3', 'ID5', 'ID6', 'ID7',    'ID9',     'ID10'],
         "1_x": np.random.randint(10, 20, 8),
         "1_y": np.random.randint(10, 20, 8)}
df1 = pd.DataFrame(data1, index=data1['ID']).drop(columns=['ID'])
dfs.append(df1)

data2 = {"ID": ['ID1',     'ID3', 'ID4', 'ID5', 'ID6', 'ID7', 'ID8', 'ID10'],
         "2_x": np.random.randint(1, 100, 8),
         "2_y": np.random.randint(1, 100, 8)}
df2 = pd.DataFrame(data2, index=data2['ID']).drop(columns=['ID'])
dfs.append(df2)

# Outer join on the ID index: an ID missing from one dataframe gets NaN in that
# dataframe's columns (those points are simply not drawn in the bottom plots).
dfm = pd.merge(df1, df2, left_index=True, right_index=True, how='outer')
# yields something like this:
#        1_x   1_y   2_x   2_y
# ID1   19.0  10.0  18.0  10.0
# ID10  14.0  14.0  11.0  15.0
# ID2   17.0  12.0   NaN   NaN
# ID3   19.0  18.0  11.0  11.0
# ID4    NaN   NaN  11.0  19.0
# ID5   14.0  10.0  13.0  10.0
# ID6   10.0  15.0  14.0  17.0
# ID7   15.0  14.0  11.0  19.0
# ID8    NaN   NaN  16.0  15.0
# ID9   11.0  10.0   NaN   NaN
# The merged source backs both the DataTable and the bottom row of figures, so a
# single selection on it highlights both.
merged_dict = {col: dfm[col].values for col in dfm.columns}
merged_dict['ID'] = list(dfm.index)
merge_source = ColumnDataSource(data=merged_dict)
columns = [TableColumn(field=x, title=x) for x in dfm.columns] + [TableColumn(field='ID', title='ID')]
data_table = DataTable(source=merge_source, columns=columns, editable=True, width=400, height=400)

figures      = []  # top row: each figure has its own, independent data source
figures2     = []  # bottom row: every figure draws from the merged data source
bokeh_frames = []  # the top figures' data sources, one per dataframe
for i, df in enumerate(dfs):
    n = i + 1  # column prefix of this dataframe ("1_x", "2_x", ...)
    source = ColumnDataSource(data=dict(ID=list(df.index),
                                        x=list(df[f"{n}_x"]),
                                        y=list(df[f"{n}_y"])))
    top = figure(width=200, height=200, tools=["lasso_select", "reset"], title=f"Plot {n}")
    bottom = figure(width=200, height=200, tools=[], title=f"Plot {n} - intersection")
    # scatter() draws circles by default; circle(size=...) is deprecated as of Bokeh 3.4.
    top.scatter('x', 'y', size=8, source=source, selection_color="firebrick")
    bottom.scatter(f"{n}_x", f"{n}_y", size=8, source=merge_source, selection_color="firebrick")
    figures.append(top)
    figures2.append(bottom)
    bokeh_frames.append(source)

# One callback shared by all top sources. Whenever any top selection changes,
# recompute the intersection of selected IDs across *all* top plots and make it
# the merged source's selection. The list of sources is passed straight through
# `args`, so no JS code has to be generated with string formatting.
intersect_selection = CustomJS(
    args=dict(sources=bokeh_frames, merged=merge_source),
    code="""
        // Set of selected IDs for each top plot (a Set also keeps an ID from
        // being counted twice within one plot).
        const selected = sources.map((src) =>
            new Set(src.selected.indices.map((k) => src.data["ID"][k])))
        // Intersect using Set.has rather than Set.prototype.intersection,
        // which older browsers do not support.
        let common = selected.length > 0 ? selected[0] : new Set()
        for (const ids of selected.slice(1)) {
            common = new Set([...common].filter((id) => ids.has(id)))
        }
        const merged_ids = merged.data["ID"]
        const rows = []
        for (let k = 0; k < merged_ids.length; k++) {
            if (common.has(merged_ids[k])) {
                rows.push(k)
            }
        }
        // Assign a new array instead of pushing into selected.indices: the
        // assignment is what notifies BokehJS, so the DataTable and the bottom
        // plots (which share `merged`) redraw without a manual change.emit().
        merged.selected.indices = rows
    """,
)
for source in bokeh_frames:
    source.selected.js_on_change("indices", intersect_selection)


# Layout: top plots, then the table, then the bottom "intersection" plots.
show(column(row(figures), data_table, row(figures2)))