Skip to content

Commit

Permalink
Fix dynamic JS import and add missing scripts for colormaps (#193)
Browse files Browse the repository at this point in the history
* Fix dynamic JS import and add missing scripts for colormaps

* included the fogotted view change logics

* Add type hints to comply with static type checking

* Fixed the typing to use built-in generics in Python 3.9+
  • Loading branch information
Navxihziq authored Jun 18, 2024
1 parent c83a23e commit 77c692c
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 65 deletions.
14 changes: 12 additions & 2 deletions streamlit_folium/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import branca
import folium
import folium.elements
import folium.plugins
import streamlit as st
import streamlit.components.v1 as components
Expand All @@ -22,6 +23,7 @@
_component_func = components.declare_component(
"st_folium", url="http://localhost:3001"
)

else:
parent_dir = os.path.dirname(os.path.abspath(__file__))
build_dir = os.path.join(parent_dir, "frontend/build")
Expand Down Expand Up @@ -367,6 +369,8 @@ def bounds_to_dict(bounds_list: list[list[float]]) -> dict[str, dict[str, float]
st.code(layer_control_string)

def walk(fig):
if isinstance(fig, branca.colormap.ColorMap):
yield fig
if isinstance(fig, folium.plugins.DualMap):
yield from walk(fig.m1)
yield from walk(fig.m2)
Expand All @@ -376,10 +380,16 @@ def walk(fig):
for child in fig._children.values():
yield from walk(child)

css_links = []
js_links = []
css_links: list[str] = []
js_links: list[str] = []

for elem in walk(folium_map):
if isinstance(elem, branca.colormap.ColorMap):
# manually add d3.js
js_links.insert(
0, "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.5/d3.min.js"
)
js_links.insert(0, "https://d3js.org/d3.v4.min.js")
css_links.extend([href for _, href in elem.default_css])
js_links.extend([src for _, src in elem.default_js])

Expand Down
115 changes: 52 additions & 63 deletions streamlit_folium/frontend/src/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ function onLayerClick(e: any) {
debouncedUpdateComponentValue(window.map)
}

function getPixelatedStyles(pixelated: boolean) {
function getPixelatedStyles(pixelated: boolean) {
if (pixelated) {
const styles = `
.leaflet-image-layer {
Expand All @@ -164,7 +164,6 @@ function getPixelatedStyles(pixelated: boolean) {
}
`
return styles

}

window.initComponent = (map: any, return_on_hover: boolean) => {
Expand All @@ -190,7 +189,7 @@ window.initComponent = (map: any, return_on_hover: boolean) => {
* the component is initially loaded, and then again every time the
* component gets new data from Python.
*/
function onRender(event: Event): void {
async function onRender(event: Event) {
// Get the RenderData from the event
const data = (event as CustomEvent<RenderData>).detail

Expand All @@ -209,30 +208,54 @@ function onRender(event: Event): void {
const layer_control: string = data.args["layer_control"]
const pixelated: boolean = data.args["pixelated"]

var finalizeOnRender = () => {
// load scripts
const loadScripts = async () => {
for (const link of js_links) {
// use promise to load scripts synchronously
await new Promise((resolve, reject) => {
const script = document.createElement("script")
script.src = link
script.async = false
script.onload = resolve
script.onerror = reject
window.document.body.appendChild(script)
})
}

css_links.forEach((link) => {
const linkTag = document.createElement("link")
linkTag.rel = "stylesheet"
linkTag.href = link
window.document.head.appendChild(linkTag)
})

const style = document.createElement("style")
style.innerHTML = getPixelatedStyles(pixelated)
window.document.head.appendChild(style)
}

// finalize rendering
const finalizeOnRender = () => {
if (
feature_group !== window.__GLOBAL_DATA__.last_feature_group ||
layer_control !== window.__GLOBAL_DATA__.last_layer_control
) {
// remove previous feature group and layer control
if (window.feature_group && window.feature_group.length > 0) {
window.feature_group.forEach((layer: Layer) => {
window.map.removeLayer(layer);
});
window.map.removeLayer(layer)
})
}

if (window.layer_control) {
window.map.removeControl(window.layer_control)
}

// update feature group and layer control cache
window.__GLOBAL_DATA__.last_feature_group = feature_group
window.__GLOBAL_DATA__.last_layer_control = layer_control

if (feature_group){
// Though using `eval` is generally a bad idea, we're using it here
// because we're evaluating code that we've generated ourselves on the
// Python side. This is safe because we're not evaluating user input, so this
// couldn't be used to execute arbitrary code.

if (feature_group) {
// eslint-disable-next-line
eval(feature_group + layer_control)
for (let key in window.map._layers) {
Expand Down Expand Up @@ -296,7 +319,6 @@ function onRender(event: Event): void {
document.body.appendChild(a)
}

const render_script = document.createElement("script")
// HACK -- update the folium-generated JS to add, most importantly,
// the map to this global variable so that it can be used elsewhere
// in the script.
Expand All @@ -322,60 +344,27 @@ function onRender(event: Event): void {
parent_div?.classList.remove("single")
parent_div?.classList.add("double")
}
}
await loadScripts().then(() => {
const render_script = document.createElement("script")

// This is only loaded once, from the onload callback
var postLoad = () => {
if (!window.map) {
render_script.innerHTML =
if (!window.map) {
render_script.innerHTML =
script +
`window.map = map_div; window.initComponent(map_div, ${return_on_hover});`
document.body.appendChild(render_script)
const html_div = document.createElement("div")
html_div.innerHTML = html
document.body.appendChild(html_div)
const styles = getPixelatedStyles(pixelated)
var styleSheet = document.createElement("style")
styleSheet.innerText = styles
document.head.appendChild(styleSheet)
}
finalizeOnRender();
}

if (js_links.length === 0) {
postLoad();
} else {
// make sure dependent js files are loaded
// before we initialize the component
var count = 0;
js_links.forEach((elem) => {
var scr = document.createElement('script');
scr.src = elem;
scr.async = false;
scr.onload = () => {
count -= 1;
if(count === 0) {
setTimeout(postLoad, 0);
}
};
document.head.appendChild(scr);
count += 1;
});
`window.map = map_div; window.initComponent(map_div, ${return_on_hover});`
document.body.appendChild(render_script)
const html_div = document.createElement("div")
html_div.innerHTML = html
document.body.appendChild(html_div)
const styles = getPixelatedStyles(pixelated)
var styleSheet = document.createElement("style")
styleSheet.innerText = styles
document.head.appendChild(styleSheet)
}

// css is okay regardless loading order
css_links.forEach((elem) => {
var link = document.createElement('link');
link.rel = "stylesheet";
link.type = "text/css";
link.href = elem;
document.head.appendChild(link);
});
Streamlit.setFrameHeight()
}
} else {
finalizeOnRender();
finalizeOnRender()
})
}

finalizeOnRender()
}

// Attach our `onRender` handler to Streamlit's render event.
Expand Down

0 comments on commit 77c692c

Please sign in to comment.