diff --git a/junit-jupiter-params/src/main/java/org/junit/jupiter/params/ParameterizedTestSpiInstantiator.java b/junit-jupiter-params/src/main/java/org/junit/jupiter/params/ParameterizedTestSpiInstantiator.java index e3ebf4222fda..24c7163ac27a 100644 --- a/junit-jupiter-params/src/main/java/org/junit/jupiter/params/ParameterizedTestSpiInstantiator.java +++ b/junit-jupiter-params/src/main/java/org/junit/jupiter/params/ParameterizedTestSpiInstantiator.java @@ -18,7 +18,7 @@ import org.junit.jupiter.api.extension.ExtensionContext; import org.junit.platform.commons.JUnitException; import org.junit.platform.commons.PreconditionViolationException; -import org.junit.platform.commons.support.ModifierSupport; +import org.junit.platform.commons.util.Preconditions; import org.junit.platform.commons.util.ReflectionUtils; /** @@ -26,8 +26,10 @@ */ class ParameterizedTestSpiInstantiator { - static T instantiate(Class spiClass, Class clazz, ExtensionContext extensionContext) { - return extensionContext.getExecutableInvoker().invoke(findConstructor(spiClass, clazz)); + static T instantiate(Class spiInterface, Class implementationClass, + ExtensionContext extensionContext) { + return extensionContext.getExecutableInvoker() // + .invoke(findConstructor(spiInterface, implementationClass)); } /** @@ -37,26 +39,34 @@ static T instantiate(Class spiClass, Class clazz, ExtensionC * which takes precedence over any other constructor. If no default * constructor is found, it checks for a single constructor and returns it. */ + private static Constructor findConstructor(Class spiInterface, + Class implementationClass) { + + Preconditions.condition(!ReflectionUtils.isInnerClass(implementationClass), + () -> String.format("The %s [%s] must be either a top-level class or a static nested class", + spiInterface.getSimpleName(), implementationClass.getName())); + + return findDefaultConstructor(implementationClass) // + .orElseGet(() -> findSingleConstructor(spiInterface, implementationClass)); + } + @SuppressWarnings("unchecked") - private static Constructor findConstructor(Class spiClass, Class clazz) { - Optional> defaultConstructor = getFirstElement( - ReflectionUtils.findConstructors(clazz, it -> it.getParameterCount() == 0)); - if (defaultConstructor.isPresent()) { - return (Constructor) defaultConstructor.get(); - } - if (ModifierSupport.isNotStatic(clazz)) { - String message = String.format("The %s [%s] must be either a top-level class or a static nested class", - spiClass.getSimpleName(), clazz.getName()); - throw new JUnitException(message); - } + private static Optional> findDefaultConstructor(Class clazz) { + return getFirstElement(ReflectionUtils.findConstructors(clazz, it -> it.getParameterCount() == 0)) // + .map(it -> (Constructor) it); + } + + private static Constructor findSingleConstructor(Class spiInterface, + Class implementationClass) { + try { - return ReflectionUtils.getDeclaredConstructor(clazz); + return ReflectionUtils.getDeclaredConstructor(implementationClass); } catch (PreconditionViolationException ex) { String message = String.format( "Failed to find constructor for %s [%s]. " + "Please ensure that a no-argument or a single constructor exists.", - spiClass.getSimpleName(), clazz.getName()); + spiInterface.getSimpleName(), implementationClass.getName()); throw new JUnitException(message); } }