Skip to content

Commit

Permalink
Handle flash attributes on htmx redirects
Browse files Browse the repository at this point in the history
  • Loading branch information
xhaggi committed Oct 25, 2024
1 parent c7c8cdf commit aac224b
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.servlet.View;
import org.springframework.web.servlet.support.ServletUriComponentsBuilder;

import java.lang.reflect.Method;
import java.time.Duration;
Expand All @@ -30,22 +29,22 @@ public HtmxHandlerInterceptor(ObjectMapper objectMapper, HtmxResponseHandlerMeth
@Override
public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {
if (modelAndView != null) {
modelAndView.getModel().values().forEach(
value -> {
if (value instanceof HtmxResponse) {
buildAndRender((HtmxResponse) value, modelAndView, request, response);
} else if (value instanceof HtmxResponse.Builder) {
buildAndRender(((HtmxResponse.Builder) value).build(), modelAndView, request, response);
}
});
for (Object value : modelAndView.getModel().values()) {
if (value instanceof HtmxResponse) {
buildAndRender((HtmxResponse) value, modelAndView, request, response);
} else if (value instanceof HtmxResponse.Builder) {
buildAndRender(((HtmxResponse.Builder) value).build(), modelAndView, request, response);
}
}
}
}

private void buildAndRender(HtmxResponse htmxResponse, ModelAndView mav, HttpServletRequest request, HttpServletResponse response) {
View v = htmxResponseHandlerMethodReturnValueHandler.toView(htmxResponse);
try {
v.render(mav.getModel(), request, response);
htmxResponseHandlerMethodReturnValueHandler.addHxHeaders(htmxResponse, response);
// ModelAndViewContainer is not available here, so flash attributes won't work
htmxResponseHandlerMethodReturnValueHandler.addHxHeaders(htmxResponse, request, response, null);
} catch (Exception e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,29 @@

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.beans.factory.ObjectFactory;
import org.springframework.core.MethodParameter;
import org.springframework.lang.Nullable;
import org.springframework.ui.ModelMap;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.method.support.HandlerMethodReturnValueHandler;
import org.springframework.web.method.support.ModelAndViewContainer;
import org.springframework.web.servlet.LocaleResolver;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.servlet.View;
import org.springframework.web.servlet.ViewResolver;
import org.springframework.web.servlet.mvc.support.RedirectAttributes;
import org.springframework.web.servlet.support.RequestContextUtils;
import org.springframework.web.util.ContentCachingResponseWrapper;

import java.util.Collection;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;

public class HtmxResponseHandlerMethodReturnValueHandler implements HandlerMethodReturnValueHandler {
Expand Down Expand Up @@ -47,7 +54,10 @@ public void handleReturnValue(Object returnValue,
HtmxResponse htmxResponse = (HtmxResponse) returnValue;
mavContainer.setView(toView(htmxResponse));

addHxHeaders(htmxResponse, webRequest.getNativeResponse(HttpServletResponse.class));
HttpServletRequest request = webRequest.getNativeRequest(HttpServletRequest.class);
HttpServletResponse response = webRequest.getNativeResponse(HttpServletResponse.class);

addHxHeaders(htmxResponse, request, response, mavContainer);
}

View toView(HtmxResponse htmxResponse) {
Expand All @@ -74,16 +84,20 @@ View toView(HtmxResponse htmxResponse) {
};
}

void addHxHeaders(HtmxResponse htmxResponse, HttpServletResponse response) {
void addHxHeaders(HtmxResponse htmxResponse, HttpServletRequest request, HttpServletResponse response, @Nullable ModelAndViewContainer mavContainer) {
addHxTriggerHeaders(response, HtmxResponseHeader.HX_TRIGGER, htmxResponse.getTriggersInternal());
addHxTriggerHeaders(response, HtmxResponseHeader.HX_TRIGGER_AFTER_SETTLE, htmxResponse.getTriggersAfterSettleInternal());
addHxTriggerHeaders(response, HtmxResponseHeader.HX_TRIGGER_AFTER_SWAP, htmxResponse.getTriggersAfterSwapInternal());

if (htmxResponse.getLocation() != null) {
if (htmxResponse.getLocation().hasContextData()) {
setHeaderJsonValue(response, HtmxResponseHeader.HX_LOCATION.getValue(), htmxResponse.getLocation());
HtmxLocation location = htmxResponse.getLocation();
if (mavContainer != null) {
saveFlashAttributes(mavContainer, request, response, location.getPath());
}
if (location.hasContextData()) {
setHeaderJsonValue(response, HtmxResponseHeader.HX_LOCATION.getValue(), location);
} else {
response.setHeader(HtmxResponseHeader.HX_LOCATION.getValue(), htmxResponse.getLocation().getPath());
response.setHeader(HtmxResponseHeader.HX_LOCATION.getValue(), location.getPath());
}
}
if (htmxResponse.getReplaceUrl() != null) {
Expand All @@ -93,6 +107,9 @@ void addHxHeaders(HtmxResponse htmxResponse, HttpServletResponse response) {
response.setHeader(HtmxResponseHeader.HX_PUSH_URL.getValue(), htmxResponse.getPushUrl());
}
if (htmxResponse.getRedirect() != null) {
if (mavContainer != null) {
saveFlashAttributes(mavContainer, request, response, htmxResponse.getRedirect());
}
response.setHeader(HtmxResponseHeader.HX_REDIRECT.getValue(), htmxResponse.getRedirect());
}
if (htmxResponse.isRefresh()) {
Expand Down Expand Up @@ -139,4 +156,21 @@ private void setHeaderJsonValue(HttpServletResponse response, String name, Objec
throw new IllegalArgumentException("Unable to set header " + name + " to " + value, e);
}
}

private void saveFlashAttributes(ModelAndViewContainer mavContainer, HttpServletRequest request, HttpServletResponse response, String location) {
mavContainer.setRedirectModelScenario(true);
ModelMap model = mavContainer.getModel();

if (model instanceof RedirectAttributes redirectAttributes) {
Map<String, ?> flashAttributes = redirectAttributes.getFlashAttributes();
if (!CollectionUtils.isEmpty(flashAttributes)) {
if (request != null) {
RequestContextUtils.getOutputFlashMap(request).putAll(flashAttributes);
if (response != null) {
RequestContextUtils.saveOutputFlashMap(location, request, response);
}
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.servlet.mvc.support.RedirectAttributes;

import java.time.Duration;
import java.util.Map;
Expand Down Expand Up @@ -34,6 +35,12 @@ public HtmxResponse hxLocationWithoutContextData() {
return HtmxResponse.builder().location("/path").build();
}

@GetMapping("/hx-location-with-flash-attributes")
public HtmxResponse hxLocationWithoutContextData(RedirectAttributes redirectAttributes) {
redirectAttributes.addFlashAttribute("flash", "test");
return HtmxResponse.builder().location("/path").build();
}

@GetMapping("/hx-push-url")
public HtmxResponse hxPushUrl() {
return HtmxResponse.builder().pushUrl("/path").build();
Expand All @@ -44,6 +51,12 @@ public HtmxResponse hxRedirect() {
return HtmxResponse.builder().redirect("/path").build();
}

@GetMapping("/hx-redirect-with-flash-attributes")
public HtmxResponse hxRedirectWithFlashAttributes(RedirectAttributes redirectAttributes) {
redirectAttributes.addFlashAttribute("flash", "test");
return HtmxResponse.builder().redirect("/path").build();
}

@GetMapping("/hx-refresh")
public HtmxResponse hxRefresh() {
return HtmxResponse.builder().refresh().build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*;

@WebMvcTest(HtmxResponseHandlerMethodReturnValueHandlerController.class)
@WithMockUser
Expand All @@ -33,6 +32,14 @@ public void testHxLocationWithoutContextData() throws Exception {
.andExpect(header().string("HX-Location", "/path"));
}

@Test
public void testHxLocationWithFlashAttributes() throws Exception {
mockMvc.perform(get("/hvhi/hx-location-with-flash-attributes"))
.andExpect(status().isOk())
.andExpect(header().string("HX-Location", "/path"))
.andExpect(flash().attribute("flash", "test"));
}

@Test
public void testHxPushUrl() throws Exception {
mockMvc.perform(get("/hvhi/hx-push-url"))
Expand All @@ -47,6 +54,14 @@ public void testHxRedirect() throws Exception {
.andExpect(header().string("HX-Redirect", "/path"));
}

@Test
public void testHxRedirectWithFlashAttributes() throws Exception {
mockMvc.perform(get("/hvhi/hx-redirect-with-flash-attributes"))
.andExpect(status().isOk())
.andExpect(header().string("HX-Redirect", "/path"))
.andExpect(flash().attribute("flash", "test"));
}

@Test
public void testHxRefresh() throws Exception {
mockMvc.perform(get("/hvhi/hx-refresh"))
Expand Down

0 comments on commit aac224b

Please sign in to comment.