150 lines
3.6 KiB
Go
150 lines
3.6 KiB
Go
|
// Package token handles JWT tokens manipulation
|
||
|
package token
|
||
|
|
||
|
import (
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"net/http"
|
||
|
|
||
|
"github.com/gin-gonic/gin"
|
||
|
jwt "github.com/golang-jwt/jwt/v5"
|
||
|
"go.uber.org/zap"
|
||
|
)
|
||
|
|
||
|
// Authenticater is the interface that wraps the Authenticate method
|
||
|
type Authenticater interface {
|
||
|
Authenticate(login, pass string) (CustomClaims, error)
|
||
|
}
|
||
|
|
||
|
// CustomClaims is the struct that represents the claims of a JWT token in EPFL context
|
||
|
type CustomClaims struct {
|
||
|
Sciper string `json:"sciper"`
|
||
|
jwt.RegisteredClaims
|
||
|
}
|
||
|
|
||
|
// Validate validates the claims of a JWT token
|
||
|
func (m CustomClaims) Validate() error {
|
||
|
if m.Sciper == "" {
|
||
|
return errors.New("sciper must be set")
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Token is the struct that represents a JWT token
|
||
|
type Token struct {
|
||
|
JWT *jwt.Token
|
||
|
}
|
||
|
|
||
|
// New creates a new JWT token
|
||
|
func New(claims CustomClaims) *Token {
|
||
|
jwt := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||
|
|
||
|
return &Token{JWT: jwt}
|
||
|
}
|
||
|
|
||
|
// Parse parses a JWT token
|
||
|
func Parse(tokenString string, secret []byte) (*Token, error) {
|
||
|
t, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||
|
// Don't forget to validate the alg is what you expect:
|
||
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||
|
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
|
||
|
}
|
||
|
|
||
|
return secret, nil
|
||
|
})
|
||
|
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return &Token{t}, nil
|
||
|
}
|
||
|
|
||
|
// Sign signs a JWT token
|
||
|
func (t *Token) Sign(secret []byte) (string, error) {
|
||
|
return t.JWT.SignedString([]byte(secret))
|
||
|
}
|
||
|
|
||
|
// Claims returns the claims of a JWT token
|
||
|
func (t *Token) Claims() jwt.MapClaims {
|
||
|
return t.JWT.Claims.(jwt.MapClaims)
|
||
|
}
|
||
|
|
||
|
// Set sets a claim in a JWT token
|
||
|
func (t *Token) Set(key string, value interface{}) {
|
||
|
t.Claims()[key] = value
|
||
|
}
|
||
|
|
||
|
// Get gets a claim from a JWT token
|
||
|
func (t *Token) Get(key string) interface{} {
|
||
|
return t.Claims()[key]
|
||
|
}
|
||
|
|
||
|
// GetString gets a claim from a JWT token as a string
|
||
|
func (t *Token) GetString(key string) string {
|
||
|
return t.Claims()[key].(string)
|
||
|
}
|
||
|
|
||
|
// ToJSON converts a JWT token to JSON
|
||
|
func (t *Token) ToJSON() (string, error) {
|
||
|
return t.JWT.Raw, nil
|
||
|
}
|
||
|
|
||
|
// PostLoginHandler is the handler that checks the login and password and returns a JWT token
|
||
|
func PostLoginHandler(log *zap.Logger, auth Authenticater, secret []byte) gin.HandlerFunc {
|
||
|
log.Info("Creating login handler")
|
||
|
return func(c *gin.Context) {
|
||
|
login := c.PostForm("login")
|
||
|
pass := c.PostForm("pass")
|
||
|
|
||
|
log.Info("Login attempt", zap.String("login", login))
|
||
|
|
||
|
claims, err := auth.Authenticate(login, pass)
|
||
|
if err != nil {
|
||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||
|
return
|
||
|
}
|
||
|
|
||
|
t := New(claims)
|
||
|
encoded, err := t.Sign(secret)
|
||
|
if err != nil {
|
||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
|
return
|
||
|
}
|
||
|
|
||
|
c.JSON(http.StatusOK, gin.H{"access_token": encoded})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// GinMiddleware is the middleware that checks the JWT token
|
||
|
func GinMiddleware(secret []byte) gin.HandlerFunc {
|
||
|
return func(c *gin.Context) {
|
||
|
authorizationHeaderString := c.GetHeader("Authorization")
|
||
|
if authorizationHeaderString == "" {
|
||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "No token provided"})
|
||
|
c.Abort()
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Check that the authorization header starts with "Bearer"
|
||
|
if len(authorizationHeaderString) < 7 || authorizationHeaderString[:7] != "Bearer " {
|
||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
|
||
|
c.Abort()
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Extract the token from the authorization header
|
||
|
tokenString := authorizationHeaderString[7:]
|
||
|
|
||
|
t, err := Parse(tokenString, secret)
|
||
|
if err != nil {
|
||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||
|
c.Abort()
|
||
|
return
|
||
|
}
|
||
|
|
||
|
c.Set("token", t)
|
||
|
c.Next()
|
||
|
}
|
||
|
}
|