diff --git a/src/components/NodeWrapper.tsx b/src/components/NodeWrapper.tsx index 781e47db..9c12996b 100644 --- a/src/components/NodeWrapper.tsx +++ b/src/components/NodeWrapper.tsx @@ -14,25 +14,42 @@ import { NodeViewDescriptorsContext, } from "../contexts/NodeViewPositionsContext.js"; +export function findChildDesc( + pos: number, + posToDesc: Map +) { + const positions = Array.from(posToDesc.keys()).sort((a, b) => b - a); + + let parentPos = null; + for (const nodePos of positions) { + if (nodePos < pos) break; + + parentPos = nodePos; + } + + return parentPos === null ? null : posToDesc.get(parentPos); +} + type NodeWrapperProps = { children: ReactNode; pos: number; }; export function NodeWrapper({ children, pos }: NodeWrapperProps) { - const { posToDesc: posToDOM, domToDesc: domToPos } = useContext( - NodeViewDescriptorsContext - ); + const { posToDesc, domToDesc } = useContext(NodeViewDescriptorsContext); const ref = useRef(null); useLayoutEffect(() => { if (!ref.current) return; + + const childDesc = findChildDesc(pos, posToDesc); + const desc: NodeViewDescriptor = { pos, dom: ref.current, - contentDOM: null, + contentDOM: childDesc?.dom.parentNode ?? null, }; - posToDOM.set(pos, desc); - domToPos.set(ref.current, desc); + posToDesc.set(pos, desc); + domToDesc.set(ref.current, desc); }); const child = Children.only(children); diff --git a/src/components/TextNodeWrapper.tsx b/src/components/TextNodeWrapper.tsx index 2bf91766..2bf45a72 100644 --- a/src/components/TextNodeWrapper.tsx +++ b/src/components/TextNodeWrapper.tsx @@ -18,7 +18,7 @@ export class TextNodeWrapper extends Component { const textNode = findDOMNode(this); if (!textNode) return; - const { posToDesc: posToDOM, domToDesc: domToPos } = this + const { posToDesc, domToDesc } = this .context as NodeViewDescriptorsContextValue; const desc: NodeViewDescriptor = { @@ -26,8 +26,8 @@ export class TextNodeWrapper extends Component { dom: textNode, contentDOM: null, }; - posToDOM.set(this.props.pos, desc); - domToPos.set(textNode, desc); + posToDesc.set(this.props.pos, desc); + domToDesc.set(textNode, desc); } componentDidUpdate(): void { @@ -36,7 +36,7 @@ export class TextNodeWrapper extends Component { const textNode = findDOMNode(this); if (!textNode) return; - const { posToDesc: posToDOM, domToDesc: domToPos } = this + const { posToDesc, domToDesc } = this .context as NodeViewDescriptorsContextValue; const desc: NodeViewDescriptor = { @@ -44,8 +44,8 @@ export class TextNodeWrapper extends Component { dom: textNode, contentDOM: null, }; - posToDOM.set(this.props.pos, desc); - domToPos.set(textNode, desc); + posToDesc.set(this.props.pos, desc); + domToDesc.set(textNode, desc); } render() {