Skip to content

Commit

Permalink
Use full header type in response header instances (#1697)
Browse files Browse the repository at this point in the history
* Use `Header'` in response headers

Use `Header'` instead of `Header` in response, so it's possible to provide
`Description`, for example:

```
type PaginationTotalCountHeader =
  Header'
    '[ Description "Indicates to the client total count of items in collection"
     , Optional
     , Strict
     ]
    "Total-Count"
    Int
```

Note: if you want to add header with description you should use `addHeader'`
or `noHeader'` which accepts `Header'` with all modifiers.
  • Loading branch information
worm2fed authored Aug 4, 2023
1 parent 02242e9 commit 72f5d5c
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 49 deletions.
23 changes: 23 additions & 0 deletions changelog.d/full-header-type
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
synopsis: Use `Header'` in response headers.
prs: #1697

description: {

Use `Header'` instead of `Header` in response, so it's possible to provide
`Description`, for example:

```
type PaginationTotalCountHeader =
Header'
'[ Description "Indicates to the client total count of items in collection"
, Optional
, Strict
]
"Total-Count"
Int
```

Note: if you want to add header with description you should use `addHeader'`
or `noHeader'` which accepts `Header'` with all modifiers.

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@ module Servant.Auth.Server.Internal.AddSetCookie where

import Blaze.ByteString.Builder (toByteString)
import qualified Data.ByteString as BS
import Data.Tagged (Tagged (..))
import qualified Network.HTTP.Types as HTTP
import Network.Wai (mapResponseHeaders)
import Servant
import Servant.API.UVerb.Union
import Servant.API.Generic
import Servant.Server.Generic
import Web.Cookie
Expand Down Expand Up @@ -76,12 +74,12 @@ instance (orig1 ~ orig2) => AddSetCookies 'Z orig1 orig2 where
instance {-# OVERLAPPABLE #-}
( Functor m
, AddSetCookies n (m old) (m cookied)
, AddHeader "Set-Cookie" SetCookie cookied new
, AddHeader mods "Set-Cookie" SetCookie cookied new
) => AddSetCookies ('S n) (m old) (m new) where
addSetCookies (mCookie `SetCookieCons` rest) oldVal =
case mCookie of
Nothing -> noHeader <$> addSetCookies rest oldVal
Just cookie -> addHeader cookie <$> addSetCookies rest oldVal
Nothing -> noHeader' <$> addSetCookies rest oldVal
Just cookie -> addHeader' cookie <$> addSetCookies rest oldVal

instance {-# OVERLAPS #-}
(AddSetCookies ('S n) a a', AddSetCookies ('S n) b b')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import Blaze.ByteString.Builder (toByteString)
import Control.Monad (MonadPlus(..), guard)
import Control.Monad.Except
import Control.Monad.Reader
import qualified Crypto.JOSE as Jose
import qualified Crypto.JWT as Jose
import Data.ByteArray (constEq)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base64 as BS64
Expand All @@ -18,11 +16,11 @@ import Data.Time.Clock (UTCTime(..), secondsToDiffTime)
import Network.HTTP.Types (methodGet)
import Network.HTTP.Types.Header(hCookie)
import Network.Wai (Request, requestHeaders, requestMethod)
import Servant (AddHeader, addHeader)
import Servant (AddHeader, addHeader')
import System.Entropy (getEntropy)
import Web.Cookie

import Servant.Auth.JWT (FromJWT (decodeJWT), ToJWT)
import Servant.Auth.JWT (FromJWT, ToJWT)
import Servant.Auth.Server.Internal.ConfigTypes
import Servant.Auth.Server.Internal.JWT (makeJWT, verifyJWT)
import Servant.Auth.Server.Internal.Types
Expand Down Expand Up @@ -132,8 +130,8 @@ applySessionCookieSettings cookieSettings setCookie = setCookie
-- provided response object with XSRF and session cookies. This should be used
-- when a user successfully authenticates with credentials.
acceptLogin :: ( ToJWT session
, AddHeader "Set-Cookie" SetCookie response withOneCookie
, AddHeader "Set-Cookie" SetCookie withOneCookie withTwoCookies )
, AddHeader mods "Set-Cookie" SetCookie response withOneCookie
, AddHeader mods "Set-Cookie" SetCookie withOneCookie withTwoCookies )
=> CookieSettings
-> JWTSettings
-> session
Expand All @@ -144,20 +142,20 @@ acceptLogin cookieSettings jwtSettings session = do
Nothing -> pure Nothing
Just sessionCookie -> do
xsrfCookie <- makeXsrfCookie cookieSettings
return $ Just $ addHeader sessionCookie . addHeader xsrfCookie
return $ Just $ addHeader' sessionCookie . addHeader' xsrfCookie

-- | Arbitrary cookie expiry time set back in history after unix time 0
expireTime :: UTCTime
expireTime = UTCTime (ModifiedJulianDay 50000) 0

-- | Adds headers to a response that clears all session cookies
-- | using max-age and expires cookie attributes.
clearSession :: ( AddHeader "Set-Cookie" SetCookie response withOneCookie
, AddHeader "Set-Cookie" SetCookie withOneCookie withTwoCookies )
clearSession :: ( AddHeader mods "Set-Cookie" SetCookie response withOneCookie
, AddHeader mods "Set-Cookie" SetCookie withOneCookie withTwoCookies )
=> CookieSettings
-> response
-> withTwoCookies
clearSession cookieSettings = addHeader clearedSessionCookie . addHeader clearedXsrfCookie
clearSession cookieSettings = addHeader' clearedSessionCookie . addHeader' clearedXsrfCookie
where
-- According to RFC6265 max-age takes precedence, but IE/Edge ignore it completely so we set both
cookieSettingsExpires = cookieSettings
Expand Down
14 changes: 10 additions & 4 deletions servant-server/test/Servant/ServerSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ import Network.Wai.Test
import Servant.API
((:<|>) (..), (:>), AuthProtect, BasicAuth,
BasicAuthData (BasicAuthData), Capture, Capture', CaptureAll,
Delete, EmptyAPI, Fragment, Get, HasStatus (StatusOf), Header,
Headers, HttpVersion, IsSecure (..), JSON, Lenient,
NoContent (..), NoContentVerb, NoFraming, OctetStream, Patch,
Delete, Description, EmptyAPI, Fragment, Get, HasStatus (StatusOf),
Header, Header', Headers, HttpVersion, IsSecure (..), JSON, Lenient,
NoContent (..), NoContentVerb, NoFraming, OctetStream, Optional, Patch,
PlainText, Post, Put, QueryFlag, QueryParam, QueryParams, Raw, RawM,
RemoteHost, ReqBody, SourceIO, StdMethod (..), Stream, Strict,
UVerb, Union, Verb, WithStatus (..), addHeader)
UVerb, Union, Verb, WithStatus (..), addHeader, addHeader')
import Servant.Server
(Context ((:.), EmptyContext), Handler, Server, ServerT, Tagged (..),
emptyServer, err401, err403, err404, hoistServer, respond, serve,
Expand Down Expand Up @@ -121,6 +121,7 @@ type VerbApi method status
:<|> "noContent" :> NoContentVerb method
:<|> "header" :> Verb method status '[JSON] (Headers '[Header "H" Int] Person)
:<|> "headerNC" :> Verb method status '[JSON] (Headers '[Header "H" Int] NoContent)
:<|> "headerD" :> Verb method status '[JSON] (Headers '[Header' '[Description "desc", Optional, Strict] "H" Int] Person)
:<|> "accept" :> ( Verb method status '[JSON] Person
:<|> Verb method status '[PlainText] String
)
Expand All @@ -133,6 +134,7 @@ verbSpec = describe "Servant.API.Verb" $ do
:<|> return NoContent
:<|> return (addHeader 5 alice)
:<|> return (addHeader 10 NoContent)
:<|> return (addHeader' 5 alice)
:<|> (return alice :<|> return "B")
:<|> return (S.source ["bytestring"])

Expand Down Expand Up @@ -177,6 +179,10 @@ verbSpec = describe "Servant.API.Verb" $ do
liftIO $ statusCode (simpleStatus response2) `shouldBe` status
liftIO $ simpleHeaders response2 `shouldContain` [("H", "5")]

response3 <- THW.request method "/headerD" [] ""
liftIO $ statusCode (simpleStatus response3) `shouldBe` status
liftIO $ simpleHeaders response3 `shouldContain` [("H", "5")]

it "handles trailing '/' gracefully" $ do
response <- THW.request method "/headerNC/" [] ""
liftIO $ statusCode (simpleStatus response) `shouldBe` status
Expand Down
10 changes: 7 additions & 3 deletions servant-swagger/src/Servant/Swagger/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ import Network.HTTP.Media (MediaType)
import Servant.API
import Servant.API.Description (FoldDescription,
reflectDescription)
import Servant.API.Generic (ToServantApi, AsApi)
import Servant.API.Modifiers (FoldRequired)

import Servant.Swagger.Internal.TypeLevel.API
Expand Down Expand Up @@ -470,10 +469,15 @@ instance (Accept c, AllAccept cs) => AllAccept (c ': cs) where
class ToResponseHeader h where
toResponseHeader :: Proxy h -> (HeaderName, Swagger.Header)

instance (KnownSymbol sym, ToParamSchema a) => ToResponseHeader (Header sym a) where
toResponseHeader _ = (hname, Swagger.Header Nothing hschema)
instance (KnownSymbol sym, ToParamSchema a, KnownSymbol (FoldDescription mods)) => ToResponseHeader (Header' mods sym a) where
toResponseHeader _ =
( hname
, Swagger.Header (transDesc $ reflectDescription (Proxy :: Proxy mods)) hschema
)
where
hname = Text.pack (symbolVal (Proxy :: Proxy sym))
transDesc "" = Nothing
transDesc desc = Just (Text.pack desc)
hschema = toParamSchema (Proxy :: Proxy a)

class AllToResponseHeader hs where
Expand Down
5 changes: 3 additions & 2 deletions servant/src/Servant/API.hs
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,9 @@ import Servant.API.ReqBody
import Servant.API.ResponseHeaders
(AddHeader, BuildHeadersTo (buildHeadersTo),
GetHeaders (getHeaders), HList (..), HasResponseHeader,
Headers (..), ResponseHeader (..), addHeader, getHeadersHList,
getResponse, lookupResponseHeader, noHeader)
Headers (..), ResponseHeader (..), addHeader, addHeader',
getHeadersHList, getResponse, lookupResponseHeader, noHeader,
noHeader')
import Servant.API.Stream
(FramingRender (..), FramingUnrender (..), FromSourceIO (..),
NetstringFraming, NewlineFraming, NoFraming, SourceIO, Stream,
Expand Down
50 changes: 31 additions & 19 deletions servant/src/Servant/API/ResponseHeaders.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ module Servant.API.ResponseHeaders
, ResponseHeader (..)
, AddHeader
, addHeader
, addHeader'
, noHeader
, noHeader'
, HasResponseHeader
, lookupResponseHeader
, BuildHeadersTo(buildHeadersTo)
Expand All @@ -37,7 +39,7 @@ module Servant.API.ResponseHeaders
import Control.DeepSeq
(NFData (..))
import Data.ByteString.Char8 as BS
(ByteString, init, pack, unlines)
(ByteString, pack)
import qualified Data.CaseInsensitive as CI
import qualified Data.List as L
import Data.Proxy
Expand All @@ -52,7 +54,9 @@ import Web.HttpApiData
import Prelude ()
import Prelude.Compat
import Servant.API.Header
(Header)
(Header')
import Servant.API.Modifiers
(Optional, Strict)
import Servant.API.UVerb.Union
import qualified Data.SOP.BasicFunctors as SOP
import qualified Data.SOP.NS as SOP
Expand Down Expand Up @@ -81,19 +85,19 @@ instance NFData a => NFData (ResponseHeader sym a) where

data HList a where
HNil :: HList '[]
HCons :: ResponseHeader h x -> HList xs -> HList (Header h x ': xs)
HCons :: ResponseHeader h x -> HList xs -> HList (Header' mods h x ': xs)

class NFDataHList xs where rnfHList :: HList xs -> ()
instance NFDataHList '[] where rnfHList HNil = ()
instance (y ~ Header h x, NFData x, NFDataHList xs) => NFDataHList (y ': xs) where
instance (y ~ Header' mods h x, NFData x, NFDataHList xs) => NFDataHList (y ': xs) where
rnfHList (HCons h xs) = rnf h `seq` rnfHList xs

instance NFDataHList xs => NFData (HList xs) where
rnf = rnfHList

type family HeaderValMap (f :: * -> *) (xs :: [*]) where
HeaderValMap f '[] = '[]
HeaderValMap f (Header h x ': xs) = Header h (f x) ': HeaderValMap f xs
HeaderValMap f (Header' mods h x ': xs) = Header' mods h (f x) ': HeaderValMap f xs


class BuildHeadersTo hs where
Expand All @@ -105,7 +109,7 @@ instance {-# OVERLAPPING #-} BuildHeadersTo '[] where
-- The current implementation does not manipulate HTTP header field lines in any way,
-- like merging field lines with the same field name in a single line.
instance {-# OVERLAPPABLE #-} ( FromHttpApiData v, BuildHeadersTo xs, KnownSymbol h )
=> BuildHeadersTo (Header h v ': xs) where
=> BuildHeadersTo (Header' mods h v ': xs) where
buildHeadersTo headers = case L.find wantedHeader headers of
Nothing -> MissingHeader `HCons` buildHeadersTo headers
Just header@(_, val) -> case parseHeader val of
Expand All @@ -130,7 +134,7 @@ instance GetHeadersFromHList '[] where
getHeadersFromHList _ = []

instance (KnownSymbol h, ToHttpApiData x, GetHeadersFromHList xs)
=> GetHeadersFromHList (Header h x ': xs)
=> GetHeadersFromHList (Header' mods h x ': xs)
where
getHeadersFromHList hdrs = case hdrs of
Header val `HCons` rest -> (headerName , toHeader val) : getHeadersFromHList rest
Expand All @@ -151,42 +155,42 @@ instance GetHeaders' '[] where
getHeaders' _ = []

instance (KnownSymbol h, GetHeadersFromHList rest, ToHttpApiData v)
=> GetHeaders' (Header h v ': rest)
=> GetHeaders' (Header' mods h v ': rest)
where
getHeaders' hs = getHeadersFromHList $ getHeadersHList hs

-- * Adding headers

-- We need all these fundeps to save type inference
class AddHeader h v orig new
| h v orig -> new, new -> h, new -> v, new -> orig where
class AddHeader (mods :: [*]) h v orig new
| mods h v orig -> new, new -> mods, new -> h, new -> v, new -> orig where
addOptionalHeader :: ResponseHeader h v -> orig -> new -- ^ N.B.: The same header can't be added multiple times

-- In this instance, we add a Header on top of something that is already decorated with some headers
instance {-# OVERLAPPING #-} ( KnownSymbol h, ToHttpApiData v )
=> AddHeader h v (Headers (fst ': rest) a) (Headers (Header h v ': fst ': rest) a) where
=> AddHeader mods h v (Headers (fst ': rest) a) (Headers (Header' mods h v ': fst ': rest) a) where
addOptionalHeader hdr (Headers resp heads) = Headers resp (HCons hdr heads)

-- In this instance, 'a' parameter is decorated with a Header.
instance {-# OVERLAPPABLE #-} ( KnownSymbol h, ToHttpApiData v , new ~ Headers '[Header h v] a)
=> AddHeader h v a new where
instance {-# OVERLAPPABLE #-} ( KnownSymbol h, ToHttpApiData v , new ~ Headers '[Header' mods h v] a)
=> AddHeader mods h v a new where
addOptionalHeader hdr resp = Headers resp (HCons hdr HNil)

-- Instances to decorate all responses in a 'Union' with headers. The functional
-- dependencies force us to consider singleton lists as the base case in the
-- recursion (it is impossible to determine h and v otherwise from old / new
-- responses if the list is empty).
instance (AddHeader h v old new) => AddHeader h v (Union '[old]) (Union '[new]) where
instance (AddHeader mods h v old new) => AddHeader mods h v (Union '[old]) (Union '[new]) where
addOptionalHeader hdr resp =
SOP.Z $ SOP.I $ addOptionalHeader hdr $ SOP.unI $ SOP.unZ $ resp

instance
( AddHeader h v old new, AddHeader h v (Union oldrest) (Union newrest)
( AddHeader mods h v old new, AddHeader mods h v (Union oldrest) (Union newrest)
-- This ensures that the remainder of the response list is _not_ empty
-- It is necessary to prevent the two instances for union types from
-- overlapping.
, oldrest ~ (a ': as), newrest ~ (b ': bs))
=> AddHeader h v (Union (old ': (a ': as))) (Union (new ': (b ': bs))) where
=> AddHeader mods h v (Union (old ': (a ': as))) (Union (new ': (b ': bs))) where
addOptionalHeader hdr resp = case resp of
SOP.Z (SOP.I rHead) -> SOP.Z $ SOP.I $ addOptionalHeader hdr rHead
SOP.S rOthers -> SOP.S $ addOptionalHeader hdr rOthers
Expand All @@ -211,21 +215,29 @@ instance
-- Note that while in your handlers type annotations are not required, since
-- the type can be inferred from the API type, in other cases you may find
-- yourself needing to add annotations.
addHeader :: AddHeader h v orig new => v -> orig -> new
addHeader :: AddHeader '[Optional, Strict] h v orig new => v -> orig -> new
addHeader = addOptionalHeader . Header

-- | Same as 'addHeader' but works with `Header'`, so it's possible to use any @mods@.
addHeader' :: AddHeader mods h v orig new => v -> orig -> new
addHeader' = addOptionalHeader . Header

-- | Deliberately do not add a header to a value.
--
-- >>> let example1 = noHeader "hi" :: Headers '[Header "someheader" Int] String
-- >>> getHeaders example1
-- []
noHeader :: AddHeader h v orig new => orig -> new
noHeader :: AddHeader '[Optional, Strict] h v orig new => orig -> new
noHeader = addOptionalHeader MissingHeader

-- | Same as 'noHeader' but works with `Header'`, so it's possible to use any @mods@.
noHeader' :: AddHeader mods h v orig new => orig -> new
noHeader' = addOptionalHeader MissingHeader

class HasResponseHeader h a headers where
hlistLookupHeader :: HList headers -> ResponseHeader h a

instance {-# OVERLAPPING #-} HasResponseHeader h a (Header h a ': rest) where
instance {-# OVERLAPPING #-} HasResponseHeader h a (Header' mods h a ': rest) where
hlistLookupHeader (HCons ha _) = ha

instance {-# OVERLAPPABLE #-} (HasResponseHeader h a rest) => HasResponseHeader h a (first ': rest) where
Expand Down
3 changes: 2 additions & 1 deletion servant/src/Servant/API/TypeLevel.hs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ import Servant.API.Capture
(Capture, CaptureAll)
import Servant.API.Fragment
import Servant.API.Header
(Header)
(Header, Header')
import Servant.API.QueryParam
(QueryFlag, QueryParam, QueryParams)
import Servant.API.ReqBody
Expand Down Expand Up @@ -130,6 +130,7 @@ type family IsElem endpoint api :: Constraint where
IsElem e (sa :<|> sb) = Or (IsElem e sa) (IsElem e sb)
IsElem (e :> sa) (e :> sb) = IsElem sa sb
IsElem sa (Header sym x :> sb) = IsElem sa sb
IsElem sa (Header' mods sym x :> sb) = IsElem sa sb
IsElem sa (ReqBody y x :> sb) = IsElem sa sb
IsElem (CaptureAll z y :> sa) (CaptureAll x y :> sb)
= IsElem sa sb
Expand Down
8 changes: 8 additions & 0 deletions servant/test/Servant/API/ResponseHeadersSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ import GHC.TypeLits
import Test.Hspec

import Servant.API.ContentTypes
import Servant.API.Description
(Description)
import Servant.API.Header
import Servant.API.Modifiers
(Optional, Strict)
import Servant.API.ResponseHeaders
import Servant.API.UVerb

Expand All @@ -27,6 +31,10 @@ spec = describe "Servant.API.ResponseHeaders" $ do
let val = addHeader 10 $ addHeader "b" 5 :: Headers '[Header "first" Int, Header "second" String] Int
getHeaders val `shouldBe` [("first", "10"), ("second", "b")]

it "adds a header with description to a value" $ do
let val = addHeader' "hi" 5 :: Headers '[Header' '[Description "desc", Optional, Strict] "test" String] Int
getHeaders val `shouldBe` [("test", "hi")]

describe "noHeader" $ do

it "does not add a header" $ do
Expand Down
Loading

0 comments on commit 72f5d5c

Please sign in to comment.