Advanced JWT Authentication with Go Fiber: A Comprehensive Guide

Introduction

In our previous post, we explored the basics of implementing JWT authentication in a Go Fiber application. Today, we’ll dive deeper into advanced concepts and best practices to create a more robust and secure authentication system. We’ll cover topics such as refresh tokens, token revocation, rate limiting, and more.

Prerequisites

  • Basic knowledge of Go and Fiber
  • Familiarity with JWT concepts
  • Understanding of the content from our previous JWT auth post

Advanced JWT Implementation

1. Refresh Tokens

Refresh tokens allow us to issue short-lived access tokens while providing a seamless authentication experience. Here’s how to implement them:

import (
    "github.com/gofiber/fiber/v2"
    "github.com/golang-jwt/jwt/v4"
    "time"
)

// TokenPair bundles the access token and the refresh token that are issued
// together by generateTokenPair.
//
// The JSON tags make a serialized TokenPair use the same snake_case key
// ("access_token") that the cookie-based handlers already put in their
// responses, instead of the default Go field names.
type TokenPair struct {
    // AccessToken is the signed access JWT.
    AccessToken string `json:"access_token"`
    // RefreshToken is the signed refresh JWT.
    RefreshToken string `json:"refresh_token"`
}

// generateTokenPair issues a fresh access/refresh token pair for userID.
//
// Both tokens are HS256-signed with jwtSecret and carry a "user_id" and an
// "exp" claim. The access token expires 15 minutes from now and the refresh
// token 7 days from now. If either signing step fails, the zero TokenPair and
// the signing error are returned.
func generateTokenPair(userID string) (TokenPair, error) {
    key := []byte(jwtSecret)

    // Access token: 15-minute lifetime.
    access := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
        "user_id": userID,
        "exp":     time.Now().Add(time.Minute * 15).Unix(),
    })
    signedAccess, err := access.SignedString(key)
    if err != nil {
        return TokenPair{}, err
    }

    // Refresh token: 7-day lifetime.
    refresh := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
        "user_id": userID,
        "exp":     time.Now().Add(time.Hour * 24 * 7).Unix(),
    })
    signedRefresh, err := refresh.SignedString(key)
    if err != nil {
        return TokenPair{}, err
    }

    pair := TokenPair{AccessToken: signedAccess, RefreshToken: signedRefresh}
    return pair, nil
}

// refreshTokenHandler exchanges a valid refresh token for a new token pair.
//
// The refresh token is read from the "Refresh-Token" request header. The
// handler responds with:
//   - 400 when the header is missing,
//   - 401 when the token fails verification or has no usable "user_id" claim,
//   - 500 when the claims cannot be read or new tokens cannot be generated,
//   - otherwise the new TokenPair as JSON.
func refreshTokenHandler(c *fiber.Ctx) error {
    refreshToken := c.Get("Refresh-Token")
    if refreshToken == "" {
        return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Refresh token is required"})
    }

    // Pin the accepted algorithm to the one generateTokenPair signs with, so a
    // token that declares any other "alg" is rejected before the key is used.
    token, err := jwt.Parse(
        refreshToken,
        func(token *jwt.Token) (interface{}, error) {
            return []byte(jwtSecret), nil
        },
        jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
    )
    if err != nil || !token.Valid {
        return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid refresh token"})
    }

    claims, ok := token.Claims.(jwt.MapClaims)
    if !ok {
        return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to parse claims"})
    }

    // A missing, non-string or empty user_id is a defect of the presented
    // token, not of the server, so it is reported as 401 rather than 500.
    userID, ok := claims["user_id"].(string)
    if !ok || userID == "" {
        return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid user ID in token"})
    }

    newTokenPair, err := generateTokenPair(userID)
    if err != nil {
        return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to generate new tokens"})
    }

    return c.JSON(newTokenPair)
}

2. Token Revocation

To implement token revocation, we’ll use a Redis cache to store revoked tokens:

import (
    "github.com/go-redis/redis/v8"
    "context"
    "time"
)

var redisClient *redis.Client

// initRedis creates the Redis client and stores it in the package-level
// redisClient variable. The client is configured for localhost:6379.
func initRedis() {
    opts := &redis.Options{Addr: "localhost:6379"}
    redisClient = redis.NewClient(opts)
}

// revokeToken marks token as revoked by writing the key "revoked:<token>" to
// Redis with the given expiration. It returns the error of that write, if any.
func revokeToken(token string, expiration time.Duration) error {
    key := "revoked:" + token
    result := redisClient.Set(context.Background(), key, true, expiration)
    return result.Err()
}

// isTokenRevoked reports whether token must be treated as revoked.
//
// It looks up the key "revoked:<token>" in Redis. The token is considered
// usable only when Redis positively answers that the key does not exist
// (redis.Nil). A found key means the token was revoked, and any other error
// (for example Redis being unreachable) is also reported as revoked, so an
// outage cannot silently re-enable tokens that were revoked earlier.
func isTokenRevoked(token string) bool {
    ctx := context.Background()
    _, err := redisClient.Get(ctx, "revoked:"+token).Result()
    // Fail closed: only a definite "key missing" clears the token.
    return err != redis.Nil
}

// verifyJWT is the middleware placed in front of protected routes.
//
// It reads the "Authorization" request header and responds with 401 when the
// header is empty or when isTokenRevoked reports the token as revoked. The
// actual JWT verification that should follow is elided in this snippet; the
// function has no final return yet and must be completed before it compiles.
func verifyJWT(c *fiber.Ctx) error {
    // The header value is used verbatim as the token; no "Bearer " prefix is
    // stripped in this snippet.
    token := c.Get("Authorization")
    if token == "" {
        return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Missing auth token"})
    }

    // Reject tokens that were revoked before doing any further verification.
    if isTokenRevoked(token) {
        return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Token has been revoked"})
    }

    // Proceed with normal JWT verification...
}

// logoutHandler revokes the token sent in the "Authorization" header.
//
// The token is stored in the revocation list for exactly as long as it would
// still have been valid. The handler responds with:
//   - 400 when the header is missing,
//   - 401 when the token cannot be parsed or verified,
//   - 500 when the claims or the "exp" claim cannot be read, or when the
//     revocation entry cannot be written,
//   - 200 otherwise.
func logoutHandler(c *fiber.Ctx) error {
    token := c.Get("Authorization")
    if token == "" {
        return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Missing auth token"})
    }

    // Parse the token to get its expiration time. The error must be checked:
    // for a malformed token jwt.Parse returns a nil *jwt.Token, and reading
    // its Claims would panic.
    parsedToken, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) {
        return []byte(jwtSecret), nil
    })
    if err != nil || !parsedToken.Valid {
        return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "Invalid auth token"})
    }

    claims, ok := parsedToken.Claims.(jwt.MapClaims)
    if !ok {
        return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to parse claims"})
    }

    // Numeric claims decode from JSON as float64.
    exp, ok := claims["exp"].(float64)
    if !ok {
        return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Invalid expiration in token"})
    }

    expirationTime := time.Unix(int64(exp), 0)
    duration := time.Until(expirationTime)

    // A token that has just expired needs no revocation entry. Skipping it
    // also avoids handing Redis a non-positive TTL, which would store the key
    // without any expiry.
    if duration <= 0 {
        return c.SendStatus(fiber.StatusOK)
    }

    if err := revokeToken(token, duration); err != nil {
        return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to revoke token"})
    }

    return c.SendStatus(fiber.StatusOK)
}

3. Rate Limiting

Implement rate limiting to protect your auth endpoints:

import (
    "github.com/gofiber/fiber/v2/middleware/limiter"
    "time"
)

// setupRateLimiter builds the rate-limiting middleware: at most 5 requests
// per minute, counted per client IP. Requests over the limit receive a 429
// response with a JSON error body.
func setupRateLimiter() fiber.Handler {
    // Requests are bucketed by the caller's IP address.
    keyByIP := func(c *fiber.Ctx) string {
        return c.IP()
    }

    // Response sent once a client has used up its allowance.
    onLimit := func(c *fiber.Ctx) error {
        body := fiber.Map{"error": "Rate limit exceeded"}
        return c.Status(fiber.StatusTooManyRequests).JSON(body)
    }

    cfg := limiter.Config{
        Max:          5,
        Expiration:   1 * time.Minute,
        KeyGenerator: keyByIP,
        LimitReached: onLimit,
    }
    return limiter.New(cfg)
}

// In your main function: attach the rate limiter to every route under the
// "/api/auth" prefix.
app.Use("/api/auth", setupRateLimiter())

4. Secure Cookie Storage

Instead of sending the refresh token in the response body, store it in a secure HTTP-only cookie:

// loginHandler authenticates the user (elided in this snippet) and issues a
// token pair. The refresh token is placed in the "refresh_token" cookie
// (HTTP-only, Secure, SameSite=Strict, expiring in 7 days); only the access
// token is returned in the JSON body, under "access_token". A token
// generation failure yields a 500 response.
func loginHandler(c *fiber.Ctx) error {
    // Authenticate user...

    tokenPair, err := generateTokenPair(user.ID)
    if err != nil {
        return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to generate tokens"})
    }

    // The refresh token travels only in a cookie that scripts cannot read.
    refreshCookie := fiber.Cookie{
        Name:     "refresh_token",
        Value:    tokenPair.RefreshToken,
        Expires:  time.Now().Add(time.Hour * 24 * 7),
        HTTPOnly: true,
        Secure:   true,
        SameSite: "Strict",
    }
    c.Cookie(&refreshCookie)

    body := fiber.Map{"access_token": tokenPair.AccessToken}
    return c.JSON(body)
}

// refreshTokenHandler is the cookie-based variant of the refresh endpoint.
//
// It reads the refresh token from the "refresh_token" cookie (400 when the
// cookie is empty), verifies it (elided in this snippet), generates a new
// token pair, replaces the cookie with the new refresh token and returns the
// new access token in the JSON body under "access_token". A token generation
// failure yields a 500 response.
func refreshTokenHandler(c *fiber.Ctx) error {
    refreshToken := c.Cookies("refresh_token")
    if refreshToken == "" {
        return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Refresh token is required"})
    }

    // Verify and parse the refresh token...

    newTokenPair, err := generateTokenPair(userID)
    if err != nil {
        return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to generate new tokens"})
    }

    // Overwrite the cookie so the client holds the new refresh token.
    rotated := fiber.Cookie{
        Name:     "refresh_token",
        Value:    newTokenPair.RefreshToken,
        Expires:  time.Now().Add(time.Hour * 24 * 7),
        HTTPOnly: true,
        Secure:   true,
        SameSite: "Strict",
    }
    c.Cookie(&rotated)

    body := fiber.Map{"access_token": newTokenPair.AccessToken}
    return c.JSON(body)
}

5. Middleware for Role-Based Access Control (RBAC)

Implement RBAC to control access to different parts of your API:

// RoleMiddleware returns a handler that lets a request continue only when the
// authenticated user has exactly the role requiredRole.
//
// It expects an earlier handler to have stored a *User under the "user" key
// of c.Locals. When no such value is present the request is answered with
// 401 instead of panicking on the type assertion; when the role does not
// match it is answered with 403.
func RoleMiddleware(requiredRole string) fiber.Handler {
    return func(c *fiber.Ctx) error {
        // Checked assertion: Locals returns nil when nothing was stored, and
        // an unchecked .(*User) would panic in that case.
        user, ok := c.Locals("user").(*User)
        if !ok || user == nil {
            return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
                "error": "Authentication required",
            })
        }
        if user.Role != requiredRole {
            return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
                "error": "Insufficient permissions",
            })
        }
        return c.Next()
    }
}

// Usage: verifyJWT runs first, then the role check, then the route handler.
app.Get("/admin", verifyJWT, RoleMiddleware("admin"), adminHandler)

Putting It All Together

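Here’s a comprehensive example that incorporates all these advanced features. Functions whose body is only the comment “Implementation from earlier example” are placeholders: paste in the corresponding code from the sections above before compiling.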

package main

import (
    "github.com/gofiber/fiber/v2"
    "github.com/golang-jwt/jwt/v4"
    "github.com/go-redis/redis/v8"
    "github.com/gofiber/fiber/v2/middleware/limiter"
    "time"
    "context"
)

var (
    // jwtSecret is the key passed to the JWT signing and parsing calls.
    // NOTE(review): this is a hard-coded placeholder value; load the real
    // secret from configuration or the environment before deploying.
    jwtSecret   = []byte("your-secret-key")
    // redisClient is the shared Redis client; it is nil until initRedis runs.
    redisClient *redis.Client
)

// User is the account record that handlers read from the request context.
//
// getUserHandler serializes a *User straight into the response, so Password
// is tagged `json:"-"` to keep it out of any JSON output. The other fields
// keep their default JSON names.
type User struct {
    ID       string
    Username string
    // Password must never be sent to clients; it is excluded from JSON.
    Password string `json:"-"`
    // Role is compared against the role required by RoleMiddleware.
    Role string
}

// TokenPair bundles the access token and the refresh token that are issued
// together by generateTokenPair.
//
// The JSON tags make a serialized TokenPair use the same snake_case key
// ("access_token") that loginHandler puts in its response, instead of the
// default Go field names.
type TokenPair struct {
    // AccessToken is the signed access JWT.
    AccessToken string `json:"access_token"`
    // RefreshToken is the signed refresh JWT.
    RefreshToken string `json:"refresh_token"`
}

// main wires up the application: it creates the Fiber app, initializes the
// Redis client, registers the rate-limited auth routes under /api/auth and
// the JWT-protected routes under /api, and then serves on port 3000.
func main() {
    app := fiber.New()
    initRedis()

    // Auth endpoints: rate limited. Only logout requires a verified token.
    auth := app.Group("/api/auth")
    auth.Use(setupRateLimiter())
    auth.Post("/login", loginHandler)
    auth.Post("/refresh", refreshTokenHandler)
    auth.Post("/logout", verifyJWT, logoutHandler)

    // Everything else under /api requires a verified token; /admin also
    // requires the "admin" role.
    api := app.Group("/api")
    api.Use(verifyJWT)
    api.Get("/user", getUserHandler)
    api.Get("/admin", RoleMiddleware("admin"), adminHandler)

    // Listen returns an error when the server cannot start (for example when
    // the port is already in use). Surface it instead of exiting silently.
    if err := app.Listen(":3000"); err != nil {
        panic(err)
    }
}

// initRedis creates the Redis client and stores it in the package-level
// redisClient variable. The client is configured for localhost:6379.
func initRedis() {
    opts := &redis.Options{Addr: "localhost:6379"}
    redisClient = redis.NewClient(opts)
}

// setupRateLimiter builds the rate-limiting middleware: at most 5 requests
// per minute, counted per client IP. Requests over the limit receive a 429
// response with a JSON error body.
func setupRateLimiter() fiber.Handler {
    // Requests are bucketed by the caller's IP address.
    keyByIP := func(c *fiber.Ctx) string {
        return c.IP()
    }

    // Response sent once a client has used up its allowance.
    onLimit := func(c *fiber.Ctx) error {
        body := fiber.Map{"error": "Rate limit exceeded"}
        return c.Status(fiber.StatusTooManyRequests).JSON(body)
    }

    cfg := limiter.Config{
        Max:          5,
        Expiration:   1 * time.Minute,
        KeyGenerator: keyByIP,
        LimitReached: onLimit,
    }
    return limiter.New(cfg)
}

// generateTokenPair returns an access/refresh TokenPair for userID.
// Placeholder: the body is elided here and must be copied from the
// refresh-token section before this file compiles.
func generateTokenPair(userID string) (TokenPair, error) {
    // Implementation from earlier example
}

// verifyJWT is the middleware registered in front of the protected routes.
// Placeholder: the body is elided here and must be copied from the
// token-revocation section (and completed) before this file compiles.
func verifyJWT(c *fiber.Ctx) error {
    // Implementation from earlier example
}

// RoleMiddleware returns a handler that restricts a route to requiredRole.
// Placeholder: the body is elided here and must be copied from the RBAC
// section before this file compiles.
func RoleMiddleware(requiredRole string) fiber.Handler {
    // Implementation from earlier example
}

// loginHandler issues a token pair for the logged-in user. Authentication is
// elided: this example uses a fixed sample user instead of checking
// credentials.
//
// The refresh token is placed in the "refresh_token" cookie (HTTP-only,
// Secure, SameSite=Strict, expiring in 7 days); only the access token is
// returned in the JSON body, under "access_token". A token generation failure
// yields a 500 response.
func loginHandler(c *fiber.Ctx) error {
    // Authenticate user...
    user := &User{ID: "123", Username: "johndoe", Role: "user"}

    tokenPair, err := generateTokenPair(user.ID)
    if err != nil {
        return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to generate tokens"})
    }

    // The refresh token travels only in a cookie that scripts cannot read.
    refreshCookie := fiber.Cookie{
        Name:     "refresh_token",
        Value:    tokenPair.RefreshToken,
        Expires:  time.Now().Add(time.Hour * 24 * 7),
        HTTPOnly: true,
        Secure:   true,
        SameSite: "Strict",
    }
    c.Cookie(&refreshCookie)

    body := fiber.Map{"access_token": tokenPair.AccessToken}
    return c.JSON(body)
}

// refreshTokenHandler serves POST /api/auth/refresh.
// Placeholder: the body is elided here and must be copied from one of the
// earlier refresh-handler examples before this file compiles.
func refreshTokenHandler(c *fiber.Ctx) error {
    // Implementation from earlier example
}

// logoutHandler serves POST /api/auth/logout.
// Placeholder: the body is elided here and must be copied from the
// token-revocation section before this file compiles.
func logoutHandler(c *fiber.Ctx) error {
    // Implementation from earlier example
}

// getUserHandler returns the authenticated user's profile as JSON.
//
// It expects an earlier handler to have stored a *User under the "user" key
// of c.Locals and answers 401 when no such value is present, instead of
// panicking on the type assertion. The response carries the ID, Username and
// Role fields under their Go field names; User.Password is deliberately left
// out so it is never sent to the client.
func getUserHandler(c *fiber.Ctx) error {
    user, ok := c.Locals("user").(*User)
    if !ok || user == nil {
        return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
            "error": "Authentication required",
        })
    }

    // Build the response explicitly rather than serializing the whole
    // struct, which would include the Password field.
    return c.JSON(fiber.Map{
        "ID":       user.ID,
        "Username": user.Username,
        "Role":     user.Role,
    })
}

// adminHandler answers the admin route with a fixed plain-text greeting.
func adminHandler(c *fiber.Ctx) error {
    const greeting = "Welcome to the admin panel"
    return c.SendString(greeting)
}

// revokeToken marks token as revoked by writing the key "revoked:<token>" to
// Redis with the given expiration. It returns the error of that write, if any.
func revokeToken(token string, expiration time.Duration) error {
    key := "revoked:" + token
    result := redisClient.Set(context.Background(), key, true, expiration)
    return result.Err()
}

// isTokenRevoked reports whether token must be treated as revoked.
//
// It looks up the key "revoked:<token>" in Redis. The token is considered
// usable only when Redis positively answers that the key does not exist
// (redis.Nil). A found key means the token was revoked, and any other error
// (for example Redis being unreachable) is also reported as revoked, so an
// outage cannot silently re-enable tokens that were revoked earlier.
func isTokenRevoked(token string) bool {
    ctx := context.Background()
    _, err := redisClient.Get(ctx, "revoked:"+token).Result()
    // Fail closed: only a definite "key missing" clears the token.
    return err != redis.Nil
}

Conclusion

This advanced implementation of JWT authentication in Go Fiber provides a robust and secure system for managing user authentication and authorization. By incorporating refresh tokens, token revocation, rate limiting, secure cookies, and role-based access control, you’ve significantly enhanced the security and functionality of your authentication system.

Remember to always follow best practices:

  • Use HTTPS in production
  • Regularly rotate your JWT secret keys
  • Implement proper error handling and logging
  • Consider using a separate service for authentication in larger applications
  • Keep your dependencies up to date

With these advanced techniques, your Go Fiber application is now equipped with a state-of-the-art authentication system capable of handling complex security requirements.