Skip to content

Commit

Permalink
[material-ui] Improve getReactElementRef() utils (mui#43022)
Browse files Browse the repository at this point in the history
Co-authored-by: Aarón García Hervás <[email protected]>
  • Loading branch information
sai6855 and aarongarciah committed Sep 19, 2024
1 parent 82a6448 commit 3c83c7d
Show file tree
Hide file tree
Showing 21 changed files with 116 additions and 52 deletions.
4 changes: 2 additions & 2 deletions packages/mui-base/src/ClickAwayListener/ClickAwayListener.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {
unstable_ownerDocument as ownerDocument,
unstable_useForkRef as useForkRef,
unstable_useEventCallback as useEventCallback,
unstable_getReactNodeRef as getReactNodeRef,
unstable_getReactElementRef as getReactElementRef,
} from '@mui/utils';

// TODO: return `EventHandlerName extends `on${infer EventName}` ? Lowercase<EventName> : never` once generatePropTypes runs with TS 4.1
Expand Down Expand Up @@ -95,7 +95,7 @@ function ClickAwayListener(props: ClickAwayListenerProps): React.JSX.Element {
};
}, []);

const handleRef = useForkRef(getReactNodeRef(children), nodeRef);
const handleRef = useForkRef(getReactElementRef(children), nodeRef);

// The handler doesn't take event.defaultPrevented into account:
//
Expand Down
4 changes: 2 additions & 2 deletions packages/mui-base/src/FocusTrap/FocusTrap.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {
elementAcceptingRef,
unstable_useForkRef as useForkRef,
unstable_ownerDocument as ownerDocument,
unstable_getReactNodeRef as getReactNodeRef,
unstable_getReactElementRef as getReactElementRef,
} from '@mui/utils';
import { FocusTrapProps } from './FocusTrap.types';

Expand Down Expand Up @@ -153,7 +153,7 @@ function FocusTrap(props: FocusTrapProps): React.JSX.Element {
const activated = React.useRef(false);

const rootRef = React.useRef<HTMLElement>(null);
const handleRef = useForkRef(getReactNodeRef(children), rootRef);
const handleRef = useForkRef(getReactElementRef(children), rootRef);
const lastKeydown = React.useRef<KeyboardEvent | null>(null);

React.useEffect(() => {
Expand Down
8 changes: 6 additions & 2 deletions packages/mui-base/src/Portal/Portal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import * as React from 'react';
import * as ReactDOM from 'react-dom';
import PropTypes from 'prop-types';
import getReactNodeRef from '@mui/utils/getReactNodeRef';
import getReactElementRef from '@mui/utils/getReactElementRef';
import {
exactProp,
HTMLElementType,
Expand Down Expand Up @@ -34,7 +34,11 @@ const Portal = React.forwardRef(function Portal(
) {
const { children, container, disablePortal = false } = props;
const [mountNode, setMountNode] = React.useState<ReturnType<typeof getContainer>>(null);
const handleRef = useForkRef(getReactNodeRef(children), forwardedRef);

const handleRef = useForkRef(
React.isValidElement(children) ? getReactElementRef(children) : null,
forwardedRef,
);

useEnhancedEffect(() => {
if (!disablePortal) {
Expand Down
4 changes: 2 additions & 2 deletions packages/mui-joy/src/Tooltip/Tooltip.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import {
unstable_useId as useId,
unstable_useTimeout as useTimeout,
unstable_Timeout as Timeout,
unstable_getReactNodeRef as getReactNodeRef,
unstable_getReactElementRef as getReactElementRef,
} from '@mui/utils';
import { Popper, unstable_composeClasses as composeClasses } from '@mui/base';
import { OverridableComponent } from '@mui/types';
Expand Down Expand Up @@ -416,7 +416,7 @@ const Tooltip = React.forwardRef(function Tooltip(inProps, ref) {
}, [handleClose, open]);

const handleUseRef = useForkRef(setChildNode, ref);
const handleRef = useForkRef(getReactNodeRef(children), handleUseRef);
const handleRef = useForkRef(getReactElementRef(children), handleUseRef);

// There is no point in displaying an empty tooltip.
if (typeof title !== 'number' && !title) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
unstable_useForkRef as useForkRef,
unstable_useEventCallback as useEventCallback,
} from '@mui/utils';
import getReactNodeRef from '@mui/utils/getReactNodeRef';
import getReactElementRef from '@mui/utils/getReactElementRef';

// TODO: return `EventHandlerName extends `on${infer EventName}` ? Lowercase<EventName> : never` once generatePropTypes runs with TS 4.1
function mapEventPropToEvent(
Expand Down Expand Up @@ -96,7 +96,7 @@ function ClickAwayListener(props: ClickAwayListenerProps): React.JSX.Element {
};
}, []);

const handleRef = useForkRef(getReactNodeRef(children), nodeRef);
const handleRef = useForkRef(getReactElementRef(children), nodeRef);

// The handler doesn't take event.defaultPrevented into account:
//
Expand Down
4 changes: 2 additions & 2 deletions packages/mui-material/src/Fade/Fade.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import * as React from 'react';
import PropTypes from 'prop-types';
import { Transition } from 'react-transition-group';
import elementAcceptingRef from '@mui/utils/elementAcceptingRef';
import getReactNodeRef from '@mui/utils/getReactNodeRef';
import getReactElementRef from '@mui/utils/getReactElementRef';
import { useTheme } from '../zero-styled';
import { reflow, getTransitionProps } from '../transitions/utils';
import useForkRef from '../utils/useForkRef';
Expand Down Expand Up @@ -49,7 +49,7 @@ const Fade = React.forwardRef(function Fade(props, ref) {

const enableStrictModeCompat = true;
const nodeRef = React.useRef(null);
const handleRef = useForkRef(nodeRef, getReactNodeRef(children), ref);
const handleRef = useForkRef(nodeRef, getReactElementRef(children), ref);

const normalizedTransitionCallback = (callback) => (maybeIsAppearing) => {
if (callback) {
Expand Down
4 changes: 2 additions & 2 deletions packages/mui-material/src/Grow/Grow.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import * as React from 'react';
import PropTypes from 'prop-types';
import useTimeout from '@mui/utils/useTimeout';
import elementAcceptingRef from '@mui/utils/elementAcceptingRef';
import getReactNodeRef from '@mui/utils/getReactNodeRef';
import getReactElementRef from '@mui/utils/getReactElementRef';
import { Transition } from 'react-transition-group';
import { useTheme } from '../zero-styled';
import { getTransitionProps, reflow } from '../transitions/utils';
Expand Down Expand Up @@ -62,7 +62,7 @@ const Grow = React.forwardRef(function Grow(props, ref) {
const theme = useTheme();

const nodeRef = React.useRef(null);
const handleRef = useForkRef(nodeRef, getReactNodeRef(children), ref);
const handleRef = useForkRef(nodeRef, getReactElementRef(children), ref);

const normalizedTransitionCallback = (callback) => (maybeIsAppearing) => {
if (callback) {
Expand Down
8 changes: 6 additions & 2 deletions packages/mui-material/src/Portal/Portal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
unstable_useEnhancedEffect as useEnhancedEffect,
unstable_useForkRef as useForkRef,
unstable_setRef as setRef,
unstable_getReactNodeRef as getReactNodeRef,
unstable_getReactElementRef as getReactElementRef,
} from '@mui/utils';
import { PortalProps } from './Portal.types';

Expand All @@ -34,7 +34,11 @@ const Portal = React.forwardRef(function Portal(
) {
const { children, container, disablePortal = false } = props;
const [mountNode, setMountNode] = React.useState<ReturnType<typeof getContainer>>(null);
const handleRef = useForkRef(getReactNodeRef(children), forwardedRef);

const handleRef = useForkRef(
React.isValidElement(children) ? getReactElementRef(children) : null,
forwardedRef,
);

useEnhancedEffect(() => {
if (!disablePortal) {
Expand Down
4 changes: 2 additions & 2 deletions packages/mui-material/src/Select/Select.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import * as React from 'react';
import PropTypes from 'prop-types';
import clsx from 'clsx';
import deepmerge from '@mui/utils/deepmerge';
import getReactNodeRef from '@mui/utils/getReactNodeRef';
import getReactElementRef from '@mui/utils/getReactElementRef';
import SelectInput from './SelectInput';
import formControlState from '../FormControl/formControlState';
import useFormControl from '../FormControl/useFormControl';
Expand Down Expand Up @@ -86,7 +86,7 @@ const Select = React.forwardRef(function Select(inProps, ref) {
filled: <StyledFilledInput ownerState={ownerState} />,
}[variant];

const inputComponentRef = useForkRef(ref, getReactNodeRef(InputComponent));
const inputComponentRef = useForkRef(ref, getReactElementRef(InputComponent));

return (
<React.Fragment>
Expand Down
4 changes: 2 additions & 2 deletions packages/mui-material/src/Slide/Slide.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { Transition } from 'react-transition-group';
import chainPropTypes from '@mui/utils/chainPropTypes';
import HTMLElementType from '@mui/utils/HTMLElementType';
import elementAcceptingRef from '@mui/utils/elementAcceptingRef';
import getReactNodeRef from '@mui/utils/getReactNodeRef';
import getReactElementRef from '@mui/utils/getReactElementRef';
import debounce from '../utils/debounce';
import useForkRef from '../utils/useForkRef';
import { useTheme } from '../zero-styled';
Expand Down Expand Up @@ -120,7 +120,7 @@ const Slide = React.forwardRef(function Slide(props, ref) {
} = props;

const childrenRef = React.useRef(null);
const handleRef = useForkRef(getReactNodeRef(children), childrenRef, ref);
const handleRef = useForkRef(getReactElementRef(children), childrenRef, ref);

const normalizedTransitionCallback = (callback) => (isAppearing) => {
if (callback) {
Expand Down
4 changes: 2 additions & 2 deletions packages/mui-material/src/Tooltip/Tooltip.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { alpha } from '@mui/system/colorManipulator';
import { useRtl } from '@mui/system/RtlProvider';
import isFocusVisible from '@mui/utils/isFocusVisible';
import appendOwnerState from '@mui/utils/appendOwnerState';
import getReactNodeRef from '@mui/utils/getReactNodeRef';
import getReactElementRef from '@mui/utils/getReactElementRef';
import { styled, useTheme } from '../zero-styled';
import memoTheme from '../utils/memoTheme';
import { useDefaultProps } from '../DefaultPropsProvider';
Expand Down Expand Up @@ -555,7 +555,7 @@ const Tooltip = React.forwardRef(function Tooltip(inProps, ref) {
};
}, [handleClose, open]);

const handleRef = useForkRef(getReactNodeRef(children), setChildNode, ref);
const handleRef = useForkRef(getReactElementRef(children), setChildNode, ref);

// There is no point in displaying an empty tooltip.
// So we exclude all falsy values, except 0, which is valid.
Expand Down
4 changes: 2 additions & 2 deletions packages/mui-material/src/Unstable_TrapFocus/FocusTrap.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {
elementAcceptingRef,
unstable_useForkRef as useForkRef,
unstable_ownerDocument as ownerDocument,
unstable_getReactNodeRef as getReactNodeRef,
unstable_getReactElementRef as getReactElementRef,
} from '@mui/utils';
import { FocusTrapProps } from './FocusTrap.types';

Expand Down Expand Up @@ -145,7 +145,7 @@ function FocusTrap(props: FocusTrapProps): React.JSX.Element {
const activated = React.useRef(false);

const rootRef = React.useRef<HTMLElement>(null);
const handleRef = useForkRef(getReactNodeRef(children), rootRef);
const handleRef = useForkRef(getReactElementRef(children), rootRef);
const lastKeydown = React.useRef<KeyboardEvent | null>(null);

React.useEffect(() => {
Expand Down
4 changes: 2 additions & 2 deletions packages/mui-material/src/Zoom/Zoom.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import * as React from 'react';
import PropTypes from 'prop-types';
import { Transition } from 'react-transition-group';
import elementAcceptingRef from '@mui/utils/elementAcceptingRef';
import getReactNodeRef from '@mui/utils/getReactNodeRef';
import getReactElementRef from '@mui/utils/getReactElementRef';
import { useTheme } from '../zero-styled';
import { reflow, getTransitionProps } from '../transitions/utils';
import useForkRef from '../utils/useForkRef';
Expand Down Expand Up @@ -49,7 +49,7 @@ const Zoom = React.forwardRef(function Zoom(props, ref) {
} = props;

const nodeRef = React.useRef(null);
const handleRef = useForkRef(nodeRef, getReactNodeRef(children), ref);
const handleRef = useForkRef(nodeRef, getReactElementRef(children), ref);

const normalizedTransitionCallback = (callback) => (maybeIsAppearing) => {
if (callback) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import getReactElementRef from '@mui/utils/getReactElementRef';
import * as React from 'react';

// @ts-expect-error
getReactElementRef(false);

// @ts-expect-error
getReactElementRef(null);

// @ts-expect-error
getReactElementRef(undefined);

// @ts-expect-error
getReactElementRef(1);

// @ts-expect-error
getReactElementRef([<div key="1" />, <div key="2" />]);

getReactElementRef(<div />);
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import { expect } from 'chai';
import getReactElementRef from '@mui/utils/getReactElementRef';
import * as React from 'react';

describe('getReactElementRef', () => {
it('should return undefined when not used correctly', () => {
// @ts-expect-error
expect(getReactElementRef(false)).to.equal(undefined);
// @ts-expect-error
expect(getReactElementRef()).to.equal(undefined);
// @ts-expect-error
expect(getReactElementRef(1)).to.equal(undefined);

const children = [<div key="1" />, <div key="2" />];
// @ts-expect-error
expect(getReactElementRef(children)).to.equal(undefined);
});

it('should return the ref of a React element', () => {
const ref = React.createRef<HTMLDivElement>();
const element = <div ref={ref} />;
expect(getReactElementRef(element)).to.equal(ref);
});

it('should return null for a fragment', () => {
const element = (
<React.Fragment>
<p>Hello</p>
<p>Hello</p>
</React.Fragment>
);
expect(getReactElementRef(element)).to.equal(null);
});

it('should return null for element with no ref', () => {
const element = <div />;
expect(getReactElementRef(element)).to.equal(null);
});
});
20 changes: 20 additions & 0 deletions packages/mui-utils/src/getReactElementRef/getReactElementRef.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import * as React from 'react';

/**
* Returns the ref of a React element handling differences between React 19 and older versions.
* It will throw runtime error if the element is not a valid React element.
*
* @param element React.ReactElement
* @returns React.Ref<any> | null | undefined
*/
export default function getReactElementRef(
element: React.ReactElement,
): React.Ref<any> | null | undefined {
// 'ref' is passed as prop in React 19, whereas 'ref' is directly attached to children in older versions
if (parseInt(React.version, 10) >= 19) {
return element.props?.ref;
}
// @ts-expect-error element.ref is not included in the ReactElement type
// https://github.com/DefinitelyTyped/DefinitelyTyped/discussions/70189
return element?.ref;
}
1 change: 1 addition & 0 deletions packages/mui-utils/src/getReactElementRef/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
export { default } from './getReactElementRef';
22 changes: 0 additions & 22 deletions packages/mui-utils/src/getReactNodeRef/getReactNodeRef.ts

This file was deleted.

1 change: 0 additions & 1 deletion packages/mui-utils/src/getReactNodeRef/index.ts

This file was deleted.

2 changes: 1 addition & 1 deletion packages/mui-utils/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,5 @@ export { default as unstable_useSlotProps } from './useSlotProps';
export type { UseSlotPropsParameters, UseSlotPropsResult } from './useSlotProps';
export { default as unstable_resolveComponentProps } from './resolveComponentProps';
export { default as unstable_extractEventHandlers } from './extractEventHandlers';
export { default as unstable_getReactNodeRef } from './getReactNodeRef';
export { default as unstable_getReactElementRef } from './getReactElementRef';
export * from './types';
4 changes: 2 additions & 2 deletions packages/mui-utils/src/useForkRef/useForkRef.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import * as React from 'react';
import { expect } from 'chai';
import { createRenderer, screen } from '@mui/internal-test-utils';
import useForkRef from './useForkRef';
import getReactNodeRef from '../getReactNodeRef';
import getReactElementRef from '../getReactElementRef';

describe('useForkRef', () => {
const { render } = createRenderer();
Expand Down Expand Up @@ -48,7 +48,7 @@ describe('useForkRef', () => {
it('does nothing if none of the forked branches requires a ref', () => {
const Outer = React.forwardRef(function Outer(props, ref) {
const { children } = props;
const handleRef = useForkRef(getReactNodeRef(children), ref);
const handleRef = useForkRef(getReactElementRef(children), ref);

return React.cloneElement(children, { ref: handleRef });
});
Expand Down

0 comments on commit 3c83c7d

Please sign in to comment.