Skip to content

Commit

Permalink
Introduce ping in sse stream
Browse files Browse the repository at this point in the history
This prevents e.g. log subscription from closing when there hasn't been
any new lines within 60 seconds.

Co-authored-by: Roger Bjørnstad <[email protected]>
Co-authored-by: Christer Edvartsen <[email protected]>
  • Loading branch information
3 people committed Mar 6, 2024
1 parent fea4ff6 commit d2beb89
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 1 deletion.
2 changes: 1 addition & 1 deletion internal/graph/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ func NewHandler(config gengql.Config, log logrus.FieldLogger) (*handler.Server,
schema := gengql.NewExecutableSchema(config)
graphHandler := handler.New(schema)
graphHandler.Use(metricsMiddleware)
graphHandler.AddTransport(transport.SSE{}) // Support subscriptions
graphHandler.AddTransport(SSE{}) // Support subscriptions
graphHandler.AddTransport(transport.Options{})
graphHandler.AddTransport(transport.POST{})
graphHandler.SetQueryCache(lru.New(1000))
Expand Down
161 changes: 161 additions & 0 deletions internal/graph/sse_transport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package graph

// This is a copy of github.com/99designs/gqlgen/graphql/handler/transport.SSE
// but with ping support.

import (
"encoding/json"
"fmt"
"io"
"log"
"mime"
"net/http"
"strings"
"sync"
"time"

"github.com/99designs/gqlgen/graphql"
"github.com/99designs/gqlgen/graphql/handler/transport"
"github.com/vektah/gqlparser/v2/gqlerror"
)

type SSE struct{}

var _ graphql.Transport = SSE{}

func (t SSE) Supports(r *http.Request) bool {
if !strings.Contains(r.Header.Get("Accept"), "text/event-stream") {
return false
}
mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
if err != nil {
return false
}
return r.Method == http.MethodPost && mediaType == "application/json"
}

func (t SSE) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
ctx := r.Context()
flusher, ok := w.(http.Flusher)
if !ok {
transport.SendErrorf(w, http.StatusInternalServerError, "streaming unsupported")
return
}
defer flusher.Flush()

w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Content-Type", "application/json")

params := &graphql.RawParams{}
start := graphql.Now()
params.Headers = r.Header
params.ReadTime = graphql.TraceTiming{
Start: start,
End: graphql.Now(),
}

bodyString, err := getRequestBody(r)
if err != nil {
gqlErr := gqlerror.Errorf("could not get json request body: %+v", err)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
log.Printf("could not get json request body: %+v", err.Error())
writeJson(w, resp)
return
}

bodyReader := io.NopCloser(strings.NewReader(bodyString))
if err = jsonDecode(bodyReader, &params); err != nil {
w.WriteHeader(http.StatusBadRequest)
gqlErr := gqlerror.Errorf(
"json request body could not be decoded: %+v body:%s",
err,
bodyString,
)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
log.Printf("decoding error: %+v body:%s", err.Error(), bodyString)
writeJson(w, resp)
return
}

rc, opErr := exec.CreateOperationContext(ctx, params)
ctx = graphql.WithOperationContext(ctx, rc)

w.Header().Set("Content-Type", "text/event-stream")
fmt.Fprint(w, ":\n\n")
flusher.Flush()

if opErr != nil {
resp := exec.DispatchError(ctx, opErr)
writeJsonWithSSE(w, resp)
} else {
lock := &sync.Mutex{}
lastMessage := time.Now()
responses, ctx := exec.DispatchOperation(ctx, rc)

go func() {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
lock.Lock()
if time.Since(lastMessage) > 30*time.Second {
fmt.Fprint(w, ":\n\n")
flusher.Flush()
}
lock.Unlock()
case <-ctx.Done():
return
}
}
}()

for {
response := responses(ctx)
if response == nil {
break
}
lock.Lock()
lastMessage = time.Now()
writeJsonWithSSE(w, response)
flusher.Flush()
lock.Unlock()
}
}

fmt.Fprint(w, "event: complete\n\n")
}

func writeJsonWithSSE(w io.Writer, response *graphql.Response) {
b, err := json.Marshal(response)
if err != nil {
panic(err)
}
fmt.Fprintf(w, "event: next\ndata: %s\n\n", b)
}

func getRequestBody(r *http.Request) (string, error) {
if r == nil || r.Body == nil {
return "", nil
}
body, err := io.ReadAll(r.Body)
if err != nil {
return "", fmt.Errorf("unable to get Request Body %w", err)
}
return string(body), nil
}

func writeJson(w io.Writer, response *graphql.Response) {
b, err := json.Marshal(response)
if err != nil {
panic(err)
}
w.Write(b)
}

func jsonDecode(r io.Reader, val interface{}) error {
dec := json.NewDecoder(r)
dec.UseNumber()
return dec.Decode(val)
}

0 comments on commit d2beb89

Please sign in to comment.