Skip to content

Commit

Permalink
Use ReflectionUtils.isInnerClass and polish implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
marcphilipp committed Sep 25, 2024
1 parent b4e35c4 commit 48cd00d
Showing 1 changed file with 26 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,18 @@
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.platform.commons.JUnitException;
import org.junit.platform.commons.PreconditionViolationException;
import org.junit.platform.commons.support.ModifierSupport;
import org.junit.platform.commons.util.Preconditions;
import org.junit.platform.commons.util.ReflectionUtils;

/**
* @since 5.12
*/
class ParameterizedTestSpiInstantiator {

static <T> T instantiate(Class<T> spiClass, Class<? extends T> clazz, ExtensionContext extensionContext) {
return extensionContext.getExecutableInvoker().invoke(findConstructor(spiClass, clazz));
static <T> T instantiate(Class<T> spiInterface, Class<? extends T> implementationClass,
ExtensionContext extensionContext) {
return extensionContext.getExecutableInvoker() //
.invoke(findConstructor(spiInterface, implementationClass));
}

/**
Expand All @@ -37,26 +39,34 @@ static <T> T instantiate(Class<T> spiClass, Class<? extends T> clazz, ExtensionC
* which takes precedence over any other constructor. If no default
* constructor is found, it checks for a single constructor and returns it.
*/
private static <T, V extends T> Constructor<? extends V> findConstructor(Class<T> spiInterface,
Class<V> implementationClass) {

Preconditions.condition(!ReflectionUtils.isInnerClass(implementationClass),
() -> String.format("The %s [%s] must be either a top-level class or a static nested class",
spiInterface.getSimpleName(), implementationClass.getName()));

return findDefaultConstructor(implementationClass) //
.orElseGet(() -> findSingleConstructor(spiInterface, implementationClass));
}

@SuppressWarnings("unchecked")
private static <T> Constructor<? extends T> findConstructor(Class<T> spiClass, Class<? extends T> clazz) {
Optional<Constructor<?>> defaultConstructor = getFirstElement(
ReflectionUtils.findConstructors(clazz, it -> it.getParameterCount() == 0));
if (defaultConstructor.isPresent()) {
return (Constructor<? extends T>) defaultConstructor.get();
}
if (ModifierSupport.isNotStatic(clazz)) {
String message = String.format("The %s [%s] must be either a top-level class or a static nested class",
spiClass.getSimpleName(), clazz.getName());
throw new JUnitException(message);
}
private static <T> Optional<Constructor<T>> findDefaultConstructor(Class<T> clazz) {
return getFirstElement(ReflectionUtils.findConstructors(clazz, it -> it.getParameterCount() == 0)) //
.map(it -> (Constructor<T>) it);
}

private static <T, V extends T> Constructor<V> findSingleConstructor(Class<T> spiInterface,
Class<V> implementationClass) {

try {
return ReflectionUtils.getDeclaredConstructor(clazz);
return ReflectionUtils.getDeclaredConstructor(implementationClass);
}
catch (PreconditionViolationException ex) {
String message = String.format(
"Failed to find constructor for %s [%s]. "
+ "Please ensure that a no-argument or a single constructor exists.",
spiClass.getSimpleName(), clazz.getName());
spiInterface.getSimpleName(), implementationClass.getName());
throw new JUnitException(message);
}
}
Expand Down

0 comments on commit 48cd00d

Please sign in to comment.