package auth import ( "testing" "time" ) func TestGenerateToken(t *testing.T) { tests := []struct { name string wantErr bool }{ { name: "generate valid token", wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { token, err := GenerateToken() if (err != nil) != tt.wantErr { t.Errorf("GenerateToken() error = %v, wantErr %v", err, tt.wantErr) return } // Verify token is 64 characters (32 bytes hex encoded) if len(token) != 64 { t.Errorf("GenerateToken() returned token of length %d, expected 64", len(token)) } // Verify tokens are unique token2, _ := GenerateToken() if token == token2 { t.Error("GenerateToken() generated duplicate tokens") } }) } } func TestAddToken(t *testing.T) { tests := []struct { name string token string userID string permissions []string wantErr bool }{ { name: "add valid token", token: "test-token-123", userID: "user-1", permissions: []string{"admin"}, wantErr: false, }, { name: "add user token", token: "user-token-456", userID: "user-2", permissions: []string{"channels:create", "channels:join"}, wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tm := NewTokenManager() tm.AddToken(tt.token, tt.userID, tt.permissions) // Verify token was added info, err := tm.ValidateToken(tt.token) if err != nil { t.Errorf("ValidateToken() error = %v", err) return } if info.UserID != tt.userID { t.Errorf("Expected userID %s, got %s", tt.userID, info.UserID) } if len(info.Permissions) != len(tt.permissions) { t.Errorf("Expected %d permissions, got %d", len(tt.permissions), len(info.Permissions)) } }) } } func TestValidateToken(t *testing.T) { tm := NewTokenManager() tests := []struct { name string token string setup func() wantErr bool errType error }{ { name: "valid token", token: "valid-token", setup: func() { tm.AddToken("valid-token", "user-1", []string{"admin"}) }, wantErr: false, }, { name: "invalid token", token: "nonexistent-token", setup: func() {}, wantErr: true, errType: ErrInvalidToken, }, { name: "revoked token", token: "revoked-token", setup: func() { tm.AddToken("revoked-token", "user-2", []string{"user"}) tm.RevokeToken("revoked-token") }, wantErr: true, errType: ErrInvalidToken, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tt.setup() info, err := tm.ValidateToken(tt.token) if (err != nil) != tt.wantErr { t.Errorf("ValidateToken() error = %v, wantErr %v", err, tt.wantErr) return } if tt.wantErr && err != tt.errType { t.Errorf("ValidateToken() error = %v, expected %v", err, tt.errType) } if !tt.wantErr && info == nil { t.Error("ValidateToken() returned nil info for valid token") } }) } } func TestRevokeToken(t *testing.T) { tm := NewTokenManager() token := "revoke-test-token" // Add token tm.AddToken(token, "user-1", []string{"admin"}) // Verify it's valid _, err := tm.ValidateToken(token) if err != nil { t.Errorf("ValidateToken() before revoke error = %v", err) return } // Revoke token err = tm.RevokeToken(token) if err != nil { t.Errorf("RevokeToken() error = %v", err) return } // Verify it's now invalid _, err = tm.ValidateToken(token) if err != ErrInvalidToken { t.Errorf("ValidateToken() after revoke error = %v, expected %v", err, ErrInvalidToken) } // Test revoking nonexistent token err = tm.RevokeToken("nonexistent-token") if err != ErrInvalidToken { t.Errorf("RevokeToken() nonexistent error = %v, expected %v", err, ErrInvalidToken) } } func TestListTokens(t *testing.T) { tm := NewTokenManager() // Add multiple tokens tm.AddToken("token-1", "user-1", []string{"admin"}) tm.AddToken("token-2", "user-2", []string{"user"}) tm.AddToken("token-3", "user-3", []string{"user"}) tokens := tm.ListTokens() if len(tokens) != 3 { t.Errorf("Expected 3 tokens, got %d", len(tokens)) } } func TestHasPermission(t *testing.T) { tm := NewTokenManager() tests := []struct { name string token string permission string setup func() want bool wantErr bool }{ { name: "admin has all permissions", token: "admin-token", permission: "channels:create", setup: func() { tm.AddToken("admin-token", "user-1", []string{"admin"}) }, want: true, wantErr: false, }, { name: "user has specific permission", token: "user-token", permission: "channels:join", setup: func() { tm.AddToken("user-token", "user-2", []string{"channels:join", "channels:create"}) }, want: true, wantErr: false, }, { name: "user missing permission", token: "limited-token", permission: "users:manage", setup: func() { tm.AddToken("limited-token", "user-3", []string{"channels:join"}) }, want: false, wantErr: false, }, { name: "invalid token", token: "bad-token", permission: "channels:create", setup: func() {}, want: false, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tm := NewTokenManager() // Fresh manager for each test tt.setup = func() { switch tt.name { case "admin has all permissions": tm.AddToken("admin-token", "user-1", []string{"admin"}) case "user has specific permission": tm.AddToken("user-token", "user-2", []string{"channels:join", "channels:create"}) case "user missing permission": tm.AddToken("limited-token", "user-3", []string{"channels:join"}) } } tt.setup() got, err := tm.HasPermission(tt.token, tt.permission) if (err != nil) != tt.wantErr { t.Errorf("HasPermission() error = %v, wantErr %v", err, tt.wantErr) return } if got != tt.want { t.Errorf("HasPermission() = %v, want %v", got, tt.want) } }) } } func TestTokenExpiration(t *testing.T) { tm := NewTokenManager() token := "expiring-token" // Create token with expiration in the past tm.AddToken(token, "user-1", []string{"admin"}) tokenInfo, _ := tm.ValidateToken(token) // Manually set expiration to past pastTime := time.Now().Add(-1 * time.Hour) tokenInfo.ExpiresAt = &pastTime // Should be expired now _, err := tm.ValidateToken(token) if err != ErrTokenExpired { t.Errorf("Expected ErrTokenExpired, got %v", err) } } func TestTokenConcurrency(t *testing.T) { tm := NewTokenManager() // Create tokens concurrently for i := 0; i < 100; i++ { go func(index int) { token, _ := GenerateToken() tm.AddToken(token, "user-1", []string{"admin"}) }(i) } // Give goroutines time to complete time.Sleep(100 * time.Millisecond) tokens := tm.ListTokens() if len(tokens) != 100 { t.Errorf("Expected 100 tokens, got %d", len(tokens)) } }