加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
jwt.go 5.59 KB
一键复制 编辑 原始数据 按行查看 历史
aesoper 提交于 2020-05-20 15:06 . 添加日志中间件
/**
* @Author: aesoper
* @Description:
* @File: jwt
* @Version: 1.0.0
* @Date: 2020/5/19 23:27
*/
package gin_middleware
import (
"errors"
"gitee.com/gin-ecosystem/gin-middleware/consts"
"github.com/dgrijalva/jwt-go"
"github.com/gin-gonic/gin"
"net/http"
"strings"
"time"
)
type (
	// JwtConfig configures the JWT middleware and token helpers. NewJwt fills
	// most unset (zero-valued) fields from DefaultJWTConfig.
	JwtConfig struct {
		// Skipper reports whether the middleware should let a request through
		// without any token check.
		Skipper Skipper
		// TokenLookup tells where to read the token from, written as
		// "source:name" with source one of "header", "query", "param" or
		// "cookie".
		TokenLookup string
		// SigningMethod names the HMAC algorithm: "HS256", "HS384" or
		// "HS512"; any other value is treated as HS256.
		SigningMethod string
		// AuthScheme is the scheme prefix (e.g. "Bearer") expected before the
		// token when it is read from a header.
		AuthScheme string
		// SigningKey is the HMAC secret used to sign and verify tokens.
		SigningKey string
		// Expire is the token lifetime. It is a time.Duration (e.g.
		// time.Hour) added to the issue time to compute the "exp" claim —
		// not a number of minutes.
		Expire time.Duration
		// Authenticator decides whether the user identified by a parsed token
		// is (still) logged in, e.g. by looking the token up in a store. A
		// non-nil error rejects the request as unauthorized. Required:
		// Middleware panics if it is nil.
		Authenticator func(ctx *gin.Context, claims jwt.MapClaims, token string) error
		// Unauthorized is called when a request fails authentication, with
		// the HTTP status code to use and the reason for the rejection.
		Unauthorized func(ctx *gin.Context, httpCode int, err error)
	}
	// jwtExtractor pulls the raw token string out of a request.
	jwtExtractor func(ctx *gin.Context) (string, error)
	// JwtEntity bundles a JwtConfig with the middleware and token methods;
	// create it with NewJwt.
	JwtEntity struct {
		config JwtConfig
	}
)
// ErrJWTMissing is returned when no token can be found in the request at the
// configured location.
var ErrJWTMissing = errors.New("未找到token")

// algorithmHS256 is the name of the default HMAC signing algorithm.
var algorithmHS256 = "HS256"

// DefaultJWTConfig holds the fallback values that NewJwt and the JwtEntity
// methods use for options left unset.
//
// WARNING: SigningKey below is a hard-coded, publicly visible secret. Any
// deployment relying on it can have its tokens forged; always configure a
// private key of your own.
var DefaultJWTConfig = JwtConfig{
	TokenLookup:   "header:" + consts.HeaderAuthorization,
	AuthScheme:    "Bearer",
	SigningMethod: algorithmHS256,
	SigningKey:    "111111111111111111",
	Expire:        time.Hour,
	Skipper:       DefaultSkipper,
	// The default rejection handler only sets the status code and stops the
	// handler chain; it writes no response body.
	Unauthorized: func(c *gin.Context, code int, _ error) {
		c.Status(code)
		c.Abort()
	},
}
// NewJwt builds a JwtEntity from c, replacing every unset (zero-valued)
// option with the corresponding value from DefaultJWTConfig. Authenticator
// has no default and must be supplied by the caller.
func NewJwt(c JwtConfig) JwtEntity {
	if c.Skipper == nil {
		c.Skipper = DefaultJWTConfig.Skipper
	}
	if c.SigningMethod == "" {
		c.SigningMethod = DefaultJWTConfig.SigningMethod
	}
	// Default the key here too, so that GenerateJwtToken (which falls back to
	// the default key) and ResolveToken (which uses config.SigningKey as-is)
	// agree on the secret.
	if c.SigningKey == "" {
		c.SigningKey = DefaultJWTConfig.SigningKey
	}
	if c.TokenLookup == "" {
		c.TokenLookup = DefaultJWTConfig.TokenLookup
	}
	if c.AuthScheme == "" {
		c.AuthScheme = DefaultJWTConfig.AuthScheme
	}
	if c.Expire == 0 {
		c.Expire = DefaultJWTConfig.Expire
	}
	if c.Unauthorized == nil {
		c.Unauthorized = DefaultJWTConfig.Unauthorized
	}
	return JwtEntity{config: c}
}
// getSigningMethod maps an HMAC algorithm name to its jwt-go implementation.
// Only "HS384" and "HS512" select a non-default method; every other name,
// including "" and "HS256", yields HS256.
func getSigningMethod(signingMethod string) *jwt.SigningMethodHMAC {
	method := jwt.SigningMethodHS256
	switch signingMethod {
	case "HS384":
		method = jwt.SigningMethodHS384
	case "HS512":
		method = jwt.SigningMethodHS512
	}
	return method
}
// Middleware returns a gin handler that authenticates requests with a JWT.
//
// Requests for which Skipper returns true pass straight through. Otherwise
// the token is extracted and resolved by validate, then handed with its
// claims to Authenticator. Any failure calls Unauthorized with
// http.StatusUnauthorized and aborts the handler chain.
//
// It panics if Authenticator is nil: without it the middleware cannot tell
// whether the token's user is logged in.
func (entity JwtEntity) Middleware() gin.HandlerFunc {
	// entity is a copy (value receiver); defaults applied here are captured by
	// the returned closure only.
	if entity.config.Skipper == nil {
		entity.config.Skipper = DefaultJWTConfig.Skipper
	}
	if entity.config.Unauthorized == nil {
		entity.config.Unauthorized = DefaultJWTConfig.Unauthorized
	}
	if entity.config.Authenticator == nil {
		panic("gin_middleware: JwtConfig.Authenticator must not be nil")
	}
	return func(ctx *gin.Context) {
		if entity.config.Skipper(ctx) {
			ctx.Next()
			return
		}
		claims, token, err := entity.validate(ctx)
		if err != nil {
			entity.config.Unauthorized(ctx, http.StatusUnauthorized, err)
			// Abort unconditionally: a user-supplied Unauthorized callback may
			// only write a response, and the request must not reach the
			// protected handler.
			ctx.Abort()
			return
		}
		// The Authenticator decides whether the token is still backed by a
		// logged-in session; rejection is treated as unauthenticated.
		if err := entity.config.Authenticator(ctx, claims, token); err != nil {
			entity.config.Unauthorized(ctx, http.StatusUnauthorized, err)
			ctx.Abort()
			return
		}
	}
}
// GenerateJwtToken signs a new token carrying claims plus the standard
// time-based claims "iat" and "nbf" (both now) and "exp" (now + Expire).
// Unset SigningMethod, SigningKey and Expire fall back to DefaultJWTConfig.
//
// The caller's claims map is not modified, and a nil map is accepted.
// It returns the signed token, the expiry as a Unix timestamp in seconds,
// and any signing error.
func (entity JwtEntity) GenerateJwtToken(claims jwt.MapClaims) (token string, expired int64, err error) {
	if entity.config.SigningMethod == "" {
		entity.config.SigningMethod = DefaultJWTConfig.SigningMethod
	}
	if entity.config.SigningKey == "" {
		entity.config.SigningKey = DefaultJWTConfig.SigningKey
	}
	// A zero lifetime would produce tokens that expire on issue.
	if entity.config.Expire == 0 {
		entity.config.Expire = DefaultJWTConfig.Expire
	}

	// Work on a copy so the caller's map is left untouched (and nil works).
	payload := make(jwt.MapClaims, len(claims)+3)
	for k, v := range claims {
		payload[k] = v
	}

	// One timestamp keeps iat, nbf and exp mutually consistent.
	now := time.Now()
	expired = now.Add(entity.config.Expire).Unix()
	payload["iat"] = now.Unix()
	payload["nbf"] = now.Unix()
	payload["exp"] = expired

	jwtToken := jwt.NewWithClaims(getSigningMethod(entity.config.SigningMethod), payload)
	token, err = jwtToken.SignedString([]byte(entity.config.SigningKey))
	return token, expired, err
}
// ResolveToken parses tokenString, verifies its HMAC signature with the
// configured SigningKey, and checks the time-based claims (exp/iat/nbf).
// Tokens using a non-HMAC algorithm (including "none") are rejected.
//
// On success it returns the token's claims and a nil error. On any failure
// (malformed token, bad signature, unexpected algorithm, expired or
// not-yet-valid token) it returns nil claims and the error.
func (entity *JwtEntity) ResolveToken(tokenString string) (jwt.MapClaims, error) {
	key := entity.config.SigningKey
	if key == "" {
		// Mirror GenerateJwtToken's fallback so the tokens it issues verify.
		key = DefaultJWTConfig.SigningKey
	}
	jwtToken, err := jwt.Parse(tokenString, func(t *jwt.Token) (interface{}, error) {
		// Pin the algorithm family; otherwise the token header would choose
		// how it is verified.
		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, errors.New("unexpected jwt signing method")
		}
		return []byte(key), nil
	})
	// jwt.Parse returns a nil token for malformed input, and a non-nil token
	// together with an error when the signature or claims are invalid. In
	// neither case may the claims be trusted.
	if err != nil {
		return nil, err
	}
	claims, ok := jwtToken.Claims.(jwt.MapClaims)
	if !ok || !jwtToken.Valid {
		return nil, errors.New("invalid jwt token")
	}
	return claims, nil
}
// validate reads the token from the request at the location given by
// TokenLookup and resolves it with ResolveToken. It returns the claims, the
// raw token string, and any error. Extraction errors are passed through
// unchanged, and an empty token yields ErrJWTMissing.
func (entity *JwtEntity) validate(ctx *gin.Context) (jwt.MapClaims, string, error) {
	raw, extractErr := getExtractor(entity.config)(ctx)
	switch {
	case extractErr != nil:
		return nil, "", extractErr
	case raw == "":
		return nil, "", ErrJWTMissing
	}
	claims, resolveErr := entity.ResolveToken(raw)
	return claims, raw, resolveErr
}
// RefreshToken validates the token carried by the current request. If that
// succeeds, it issues a new token with the same claims, whose iat/nbf/exp
// are re-stamped by GenerateJwtToken. On a validation failure it returns an
// empty token, a zero expiry, and the error.
func (entity *JwtEntity) RefreshToken(ctx *gin.Context) (tokenStr string, exp int64, err error) {
	var current jwt.MapClaims
	if current, _, err = entity.validate(ctx); err != nil {
		return "", 0, err
	}
	return entity.GenerateJwtToken(current)
}
// getExtractor builds the token extractor described by config.TokenLookup,
// a "source:name" string with source "header", "query", "param" or "cookie".
// Unknown sources, and lookups with more than one ":", read the header named
// by the second field using config.AuthScheme.
//
// A lookup without any ":" (including the empty string of a zero-value
// config) falls back to the header named by consts.HeaderAuthorization.
// Previously such a lookup indexed past the end of the split result and
// panicked on every request.
func getExtractor(config JwtConfig) jwtExtractor {
	parts := strings.Split(config.TokenLookup, ":")
	if len(parts) < 2 {
		parts = []string{"header", consts.HeaderAuthorization}
	}
	if len(parts) == 2 {
		switch parts[0] {
		case "query":
			return getJwtFromQuery(parts[1])
		case "param":
			return getJwtFromParam(parts[1])
		case "cookie":
			return getJwtFromCookie(parts[1])
		}
	}
	return getJwtTokenFromHeader(parts[1], config.AuthScheme)
}
// getJwtTokenFromHeader returns an extractor that reads the token from the
// given request header. The header is expected to look like
// "<authScheme> <token>". The scheme is matched case-insensitively (auth
// schemes are case-insensitive per RFC 7235) and must be followed by a
// single space. With an empty authScheme the whole header value is the
// token. ErrJWTMissing is returned when no token is present.
func getJwtTokenFromHeader(header, authScheme string) jwtExtractor {
	return func(ctx *gin.Context) (string, error) {
		auth := ctx.Request.Header.Get(header)
		if authScheme == "" {
			// No scheme prefix to strip. The old code still skipped one
			// character here and corrupted the token.
			if auth == "" {
				return "", ErrJWTMissing
			}
			return auth, nil
		}
		l := len(authScheme)
		if len(auth) > l+1 && strings.EqualFold(auth[:l], authScheme) && auth[l] == ' ' {
			return auth[l+1:], nil
		}
		return "", ErrJWTMissing
	}
}
// getJwtFromQuery returns an extractor that reads the token from the URL
// query parameter param. A missing or empty parameter yields ErrJWTMissing.
func getJwtFromQuery(param string) jwtExtractor {
	return func(ctx *gin.Context) (string, error) {
		if value := ctx.Query(param); value != "" {
			return value, nil
		}
		return "", ErrJWTMissing
	}
}
// getJwtFromParam returns an extractor that reads the token from the route
// path parameter param. A missing or empty parameter yields ErrJWTMissing.
func getJwtFromParam(param string) jwtExtractor {
	return func(c *gin.Context) (string, error) {
		if value := c.Param(param); value != "" {
			return value, nil
		}
		return "", ErrJWTMissing
	}
}
// getJwtFromCookie returns an extractor that reads the token from the cookie
// called name. Any error from gin's cookie lookup is reported as
// ErrJWTMissing.
func getJwtFromCookie(name string) jwtExtractor {
	return func(c *gin.Context) (string, error) {
		value, lookupErr := c.Cookie(name)
		if lookupErr == nil {
			return value, nil
		}
		return "", ErrJWTMissing
	}
}
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化