diff --git a/fakecas.go b/fakecas.go index 72f41f6..a53624a 100644 --- a/fakecas.go +++ b/fakecas.go @@ -1,168 +1,43 @@ package main import ( - "encoding/json" - "encoding/xml" "flag" "fmt" + "github.com/labstack/echo" + mw "github.com/labstack/echo/middleware" "gopkg.in/mgo.v2" - "gopkg.in/mgo.v2/bson" - "log" - "net/http" - "net/url" - "strings" ) -type OAuthAttributes struct { - LastName string `json:"lastName"` - FirstName string `json:"firstName"` -} - -type OAuthResponse struct { - Id string `json:"id"` - Attributes OAuthAttributes `json:"attributes"` -} - -type User struct { - Id string `bson:"_id"` - Username string `bson:"username"` - Emails []string `bson:"emails"` - Fullname string `bson:"fullname"` - GivenName string `bson:"given_name"` - FamilyName string `bson:"family_name"` -} - -type ServiceResponse struct { - Xmlns string `xml:"xmlns:cas,attr"` - XMLName xml.Name `xml:"cas:serviceResponse"` - User string `xml:"cas:authenticationSuccess>cas:user"` - NewLogin bool `xml:"cas:authenticationSuccess>cas:attributes>cas:isFromNewLogin"` - Date string `xml:"cas:authenticationSuccess>cas:attributes>cas:authenticationDate"` - GivenName string `xml:"cas:authenticationSuccess>cas:attributes>cas:givenName"` - FamilyName string `xml:"cas:authenticationSuccess>cas:attributes>cas:familyName"` - LongTermAuth bool `xml:"cas:authenticationSuccess>cas:attributes>cas:longTermAuthenticationRequestTokenUsed"` - AccessToken string `xml:"cas:authenticationSuccess>cas:attributes>accessToken"` - UserName string `xml:"cas:authenticationSuccess>cas:attributes>username"` -} - var ( - host = flag.String("host", "localhost:8080", "The host to bind to") - databasename = flag.String("dbname", "osf20130903", "The name of your OSF database") - databaseaddress = flag.String("dbaddress", "localhost:27017", "The address of your mongodb. ie: localhost:27017") + Host = flag.String("host", "localhost:8080", "The host to bind to") + DatabaseName = flag.String("dbname", "osf20130903", "The name of your OSF database") + DatabaseAddress = flag.String("dbaddress", "localhost:27017", "The address of your mongodb. ie: localhost:27017") + DatabaseSession mgo.Session + UserCollection *mgo.Collection ) func main() { flag.Parse() + e := echo.New() + e.Use(mw.Logger()) + e.Use(mw.Recover()) + e.Use(CorsMiddleWare()) - http.HandleFunc("/login", login) - http.HandleFunc("/logout", logout) - http.HandleFunc("/oauth2/profile", oauth) - http.HandleFunc("/p3/serviceValidate", serviceValidate) - - fmt.Println("Expecting database", *databasename, " to be running at", *databaseaddress) - fmt.Println("Listening on", *host) + e.Post("/login", Login) + e.Get("/logout", Logout) + e.Get("/oauth2/profile", OAuth) + e.Get("/p3/serviceValidate", ServiceValidate) - err := http.ListenAndServe(*host, nil) + fmt.Println("Expecting database", *DatabaseName, " to be running at", *DatabaseAddress) + fmt.Println("Listening on", *Host) - if err != nil { - log.Fatal("ListenAndServe: ", err) - } -} - -func login(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Access-Control-Allow-Origin", "*") - redir, err := url.Parse(r.FormValue("service")) - - if err != nil { - log.Fatal(err) - } - - query := redir.Query() - query.Set("ticket", r.FormValue("username")) - redir.RawQuery = query.Encode() - - fmt.Println("Logging in and redirecting to", redir) - http.Redirect(w, r, redir.String(), http.StatusFound) -} - -func logout(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Access-Control-Allow-Origin", "*") - fmt.Println("Logging out and redirecting to", r.FormValue("service")) - http.Redirect(w, r, r.FormValue("service"), http.StatusFound) -} - -func serviceValidate(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Access-Control-Allow-Origin", "*") - - session, err := mgo.Dial(*databaseaddress) + DatabaseSession, err := mgo.Dial(*DatabaseAddress) if err != nil { panic(err) } - defer session.Close() - - c := session.DB(*databasename).C("user") - - result := User{} - err = c.Find(bson.M{"emails": r.FormValue("ticket")}).One(&result) + defer DatabaseSession.Close() - if err != nil { - fmt.Println("User", r.FormValue("ticket"), "not found.") - http.NotFound(w, r) - return - } - - response := ServiceResponse{ - Xmlns: "http://www.yale.edu/tp/cas", - User: result.Id, - NewLogin: true, - Date: "Eh", - GivenName: result.GivenName, - FamilyName: result.FamilyName, - AccessToken: result.Id, - UserName: result.Username, - } - - x, err := xml.MarshalIndent(response, "", " ") - if err != nil { - panic(err) - } - - w.Header().Set("Content-Type", "application/xml") - w.Write(x) -} - -func oauth(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Access-Control-Allow-Origin", "*") - - session, err := mgo.Dial(*databaseaddress) - if err != nil { - panic(err) - } - defer session.Close() - - c := session.DB(*databasename).C("user") - - result := User{} - err = c.Find(bson.M{"_id": strings.Replace(r.Header.Get("Authorization"), "Bearer ", "", 1)}).One(&result) - - if err != nil { - http.NotFound(w, r) - return - } - - js, err := json.Marshal(OAuthResponse{ - Id: result.Id, - Attributes: OAuthAttributes{ - LastName: result.FamilyName, - FirstName: result.GivenName, - }, - }) - - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } + UserCollection = DatabaseSession.DB(*DatabaseName).C("user") - w.Header().Set("Content-Type", "application/json") - w.Write(js) + e.Run(*Host) } diff --git a/middleware.go b/middleware.go new file mode 100644 index 0000000..6b88def --- /dev/null +++ b/middleware.go @@ -0,0 +1,21 @@ +package main + +import "github.com/labstack/echo" + +func CorsMiddleWare() echo.MiddlewareFunc { + return func(h echo.HandlerFunc) echo.HandlerFunc { + return func(c *echo.Context) error { + c.Response().Header().Add("Access-Control-Allow-Origin", "*") + c.Response().Header().Add("Access-Control-Allow-Headers", "Range, Content-Type, Authorization, Cache-Control, X-Requested-With") + c.Response().Header().Add("Access-Control-Expose-Headers", "Range, Content-Type, Authorization, Cache-Control, X-Requested-With") + c.Response().Header().Add("Cache-control", "no-store, no-cache, must-revalidate, max-age=0") + + if c.Request().Method == "OPTIONS" { + c.Response().Header().Add("Access-Control-Allow-Methods", "GET, PUT, POST, DELETE") + return c.NoContent(204) + } + h(c) + return nil + } + } +} diff --git a/types.go b/types.go new file mode 100644 index 0000000..6154872 --- /dev/null +++ b/types.go @@ -0,0 +1,35 @@ +package main + +import "encoding/xml" + +type OAuthAttributes struct { + LastName string `json:"lastName"` + FirstName string `json:"firstName"` +} + +type OAuthResponse struct { + Id string `json:"id"` + Attributes OAuthAttributes `json:"attributes"` +} + +type User struct { + Id string `bson:"_id"` + Username string `bson:"username"` + Emails []string `bson:"emails"` + Fullname string `bson:"fullname"` + GivenName string `bson:"given_name"` + FamilyName string `bson:"family_name"` +} + +type ServiceResponse struct { + Xmlns string `xml:"xmlns:cas,attr"` + XMLName xml.Name `xml:"cas:serviceResponse"` + User string `xml:"cas:authenticationSuccess>cas:user"` + NewLogin bool `xml:"cas:authenticationSuccess>cas:attributes>cas:isFromNewLogin"` + Date string `xml:"cas:authenticationSuccess>cas:attributes>cas:authenticationDate"` + GivenName string `xml:"cas:authenticationSuccess>cas:attributes>cas:givenName"` + FamilyName string `xml:"cas:authenticationSuccess>cas:attributes>cas:familyName"` + LongTermAuth bool `xml:"cas:authenticationSuccess>cas:attributes>cas:longTermAuthenticationRequestTokenUsed"` + AccessToken string `xml:"cas:authenticationSuccess>cas:attributes>accessToken"` + UserName string `xml:"cas:authenticationSuccess>cas:attributes>username"` +} diff --git a/views.go b/views.go new file mode 100644 index 0000000..f0f6d0a --- /dev/null +++ b/views.go @@ -0,0 +1,75 @@ +package main + +import ( + "fmt" + "github.com/labstack/echo" + "gopkg.in/mgo.v2/bson" + "net/http" + "net/url" + "strings" +) + +func Login(c *echo.Context) error { + redir, err := url.Parse(c.Form("service")) + + if err != nil { + c.Error(err) + return nil + } + + query := redir.Query() + query.Set("ticket", c.Form("username")) + redir.RawQuery = query.Encode() + + fmt.Println("Logging in and redirecting to", redir) + c.Redirect(http.StatusFound, redir.String()) + return nil +} + +func Logout(c *echo.Context) error { + fmt.Println("Logging out and redirecting to", c.Form("service")) + c.Redirect(http.StatusFound, c.Form("service")) + return nil +} + +func ServiceValidate(c *echo.Context) error { + result := User{} + err := UserCollection.Find(bson.M{"emails": c.Form("ticket")}).One(&result) + + if err != nil { + fmt.Println("User", c.Form("ticket"), "not found.") + return c.NoContent(http.StatusNotFound) + } + + response := ServiceResponse{ + Xmlns: "http://www.yale.edu/tp/cas", + User: result.Id, + NewLogin: true, + Date: "Eh", + GivenName: result.GivenName, + FamilyName: result.FamilyName, + AccessToken: result.Id, + UserName: result.Username, + } + + return c.XML(http.StatusOK, response) +} + +func OAuth(c *echo.Context) error { + result := User{} + err := UserCollection.Find(bson.M{ + "_id": strings.Replace(c.Request().Header.Get("Authorization"), "Bearer ", "", 1), + }).One(&result) + + if err != nil { + return c.NoContent(http.StatusNotFound) + } + + return c.JSON(200, OAuthResponse{ + Id: result.Id, + Attributes: OAuthAttributes{ + LastName: result.FamilyName, + FirstName: result.GivenName, + }, + }) +}