Set intersection of two+ lasso tools

It turns out I was very close to getting it right! I made some changes, including using the ID column instead of ix in the bokeh_frames.append(...) call, and a few minor changes to the JavaScript.

I looked closely at this code: Bokeh hover on two tables with different data sources

and I found that I must update the DataTable object itself (data_table in my code above), and not the ColumnDataSource (merge_source in my code above), to get the behavior I wanted. I don’t completely understand this at the moment - do I actually need both objects (merge_source and data_table) in my JavaScript callback to do what I wanted? It seems like I’m essentially loading the same data twice when I could just access everything from DataTable.data.

Here is the minimal example, which now works. Use the lasso to select points in both of the top plots.

# from here: https://discourse.bokeh.org/t/bokeh-hover-on-two-tables-with-different-data-sources/8202
import pandas as pd
import numpy as np
from bokeh.plotting import figure, show
from bokeh.models   import ColumnDataSource, DataTable, TableColumn, CustomJS
from bokeh.layouts  import row, column


# Two toy data sets keyed by ID. The ID lists overlap only partially, so the
# merged frame below contains NaN wherever an ID is missing from one set.
dfs = []

data1 = {
    "ID": ['ID1', 'ID2', 'ID3', 'ID5', 'ID6', 'ID7', 'ID9', 'ID10'],
    "1_x": np.random.randint(10, 20, 8),
    "1_y": np.random.randint(10, 20, 8),
}
# The IDs become the row index; only the coordinate columns stay as columns.
df1 = pd.DataFrame(
    {name: values for name, values in data1.items() if name != 'ID'},
    index=data1['ID'],
)
dfs.append(df1)

data2 = {
    "ID": ['ID1', 'ID3', 'ID4', 'ID5', 'ID6', 'ID7', 'ID8', 'ID10'],
    "2_x": np.random.randint(1, 100, 8),
    "2_y": np.random.randint(1, 100, 8),
}
df2 = pd.DataFrame(
    {name: values for name, values in data2.items() if name != 'ID'},
    index=data2['ID'],
)
dfs.append(df2)

# Outer join on the ID index: every ID from either data set gets one row.
dfm = df1.join(df2, how='outer')
# yields something like this:
#        1_x   1_y   2_x   2_y
# ID1   19.0  10.0  18.0  10.0
# ID10  14.0  14.0  11.0  15.0
# ID2   17.0  12.0   NaN   NaN
# ID3   19.0  18.0  11.0  11.0
# ID4    NaN   NaN  11.0  19.0
# ID5   14.0  10.0  13.0  10.0
# ID6   10.0  15.0  14.0  17.0
# ID7   15.0  14.0  11.0  19.0
# ID8    NaN   NaN  16.0  15.0
# ID9   11.0  10.0   NaN   NaN

# Column-name -> values mapping of the merged frame, plus the index under 'ID'.
merged_dict = {name: series.values for name, series in dfm.items()}
merged_dict['ID'] = dfm.index
# A single data source built from the merged columns; the table below reads from it.
merge_source = ColumnDataSource(data=merged_dict)
# One table column per merged data column, with the ID column placed last.
columns = [TableColumn(field=name, title=name) for name in [*dfm.columns, 'ID']]
data_table = DataTable(
    source=merge_source,
    columns=columns,
    editable=True,
    width=400,
    height=400,
)

figures      = []  # top row: one plot per data set, each backed by its own data source
figures2     = []  # bottom row: the same plots, drawn from the merged data source
bokeh_frames = []  # the per-data-set ColumnDataSource objects used by the top row
for plot_number, frame in enumerate(dfs, start=1):
    # Column names of this data set inside the merged source, e.g. "1_x" / "1_y".
    x_col = f"{plot_number}_x"
    y_col = f"{plot_number}_y"

    # Independent source for the top plot; 'ID' is carried along with x and y.
    source = ColumnDataSource(data=dict(ID=list(frame.index),
                                        x=list(frame[x_col]),
                                        y=list(frame[y_col])))
    bokeh_frames.append(source)

    top_plot = figure(width=200,
                      height=200,
                      tools=["lasso_select", "reset"],
                      title=f"Plot {plot_number}")
    # scatter() rather than circle(): circle() with a screen-unit `size` is
    # deprecated since Bokeh 3.4; scatter() draws circle markers by default.
    top_plot.scatter('x', 'y',
                     size=8, source=source, selection_color="firebrick")
    figures.append(top_plot)

    # No tools here: this plot only displays the selection held by merge_source.
    bottom_plot = figure(width=200,
                         height=200,
                         tools=[],
                         title=f"Plot {plot_number} - intersection")
    bottom_plot.scatter(x_col, y_col,
                        size=8, source=merge_source, selection_color="firebrick")
    figures2.append(bottom_plot)

# Keep the selection of merge_source equal to the intersection of the
# selections made in the top plots: a row of the merged source is selected only
# when its ID is currently selected in every one of the per-plot sources.

# JavaScript-side names for the per-plot sources: s1, s2, ...
source_names = [f"s{n}" for n in range(1, len(bokeh_frames) + 1)]

# The arguments are identical for every source, so they are built only once.
callback_args = dict(zip(source_names, bokeh_frames))
callback_args["sMerge"] = merge_source
callback_args["data_table"] = data_table

# Only the first line depends on the number of plots; the rest is plain
# JavaScript, kept outside the f-string so its braces need no escaping.
intersection_js = f"const sources = [{', '.join(source_names)}];\n" + """
    // For every ID, count in how many source plots it is currently selected.
    const counts = new Map();
    for (const source of sources) {
        const ids = source.data["ID"];
        for (const row of source.selected.indices) {
            const id = ids[row];
            counts.set(id, (counts.get(id) || 0) + 1);
        }
    }

    // An ID selected in every plot belongs to the intersection; translate it
    // to its row number in the merged source.
    const merged_ids = sMerge.data["ID"];
    const intersection = [];
    for (const [id, count] of counts) {
        if (count === sources.length) {
            const row = merged_ids.indexOf(id);
            if (row >= 0) {
                intersection.push(row);
            }
        }
    }

    // Assign a brand-new array instead of pushing into the existing one:
    // in-place mutation does not raise a property-change event in BokehJS.
    sMerge.selected.indices = intersection;

    // Explicitly refresh the table as well, as the original callback did.
    data_table.change.emit();
"""

intersection_callback = CustomJS(args=callback_args, code=intersection_js)
for source in bokeh_frames:
    # Re-evaluate the intersection whenever any top plot's selection changes.
    source.selected.js_on_change("indices", intersection_callback)


# Page layout, top to bottom: the selectable plots, the merged table, and the
# plots that display the intersection of the selections.
top_row = row(figures)
bottom_row = row(figures2)
show(column(top_row, data_table, bottom_row))