package grpc

import (
	"context"
	"fmt"
	"net"
	"strings"

	"github.com/sorti/openspeak/internal/auth"
	"github.com/sorti/openspeak/internal/channel"
	"github.com/sorti/openspeak/internal/logger"
	"github.com/sorti/openspeak/internal/presence"
	"github.com/sorti/openspeak/internal/voice"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
)

// Server wraps a gRPC server together with the listener it serves on and the
// managers handed to the registered service handlers.
type Server struct {
	grpc            *grpc.Server
	listener        net.Listener
	logger          *logger.Logger
	tokenManager    *auth.TokenManager
	channelManager  *channel.Manager
	presenceManager *presence.Manager
	voiceRouter     *voice.Router
	port            int // port actually bound by listener (resolved by the OS when 0 was requested)
}

// NewServer opens a TCP listener on the given port (all local addresses),
// builds a gRPC server whose unary RPCs pass through authUnaryInterceptor, and
// registers the auth, channel, presence and voice service handlers on it.
//
// Passing port 0 lets the OS choose a free port; the port that was actually
// bound is recorded so Start logs the real value.
//
// It returns an error only if the listener cannot be opened.
func NewServer(port int, log *logger.Logger, tm *auth.TokenManager, cm *channel.Manager, pm *presence.Manager, vr *voice.Router) (*Server, error) {
	listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
	if err != nil {
		return nil, fmt.Errorf("failed to listen on port %d: %w", port, err)
	}

	// Record the port the OS actually bound; it differs from the argument when port == 0.
	if addr, ok := listener.Addr().(*net.TCPAddr); ok {
		port = addr.Port
	}

	// NOTE(review): only unary RPCs are authenticated here. Any streaming RPCs
	// registered below bypass this check unless they validate the token
	// themselves — confirm, or add a matching grpc.ChainStreamInterceptor.
	grpcServer := grpc.NewServer(
		grpc.ChainUnaryInterceptor(
			authUnaryInterceptor(log, tm),
		),
	)

	s := &Server{
		grpc:            grpcServer,
		listener:        listener,
		logger:          log,
		tokenManager:    tm,
		channelManager:  cm,
		presenceManager: pm,
		voiceRouter:     vr,
		port:            port,
	}

	// Register service handlers
	registerAuthHandlers(grpcServer, s)
	registerChannelHandlers(grpcServer, s)
	registerPresenceHandlers(grpcServer, s)
	registerVoiceHandlers(grpcServer, s)

	return s, nil
}

// Start serves gRPC requests on the listener opened by NewServer. It blocks
// until the server stops and returns whatever grpc.Server.Serve returns.
func (s *Server) Start() error {
	s.logger.Info(fmt.Sprintf("Starting gRPC server on port %d", s.port))
	return s.grpc.Serve(s.listener)
}

// Stop shuts the server down with grpc.Server.GracefulStop, which stops
// accepting new connections and RPCs and waits for pending RPCs to finish.
//
// NOTE(review): GracefulStop has no deadline, so a long-lived RPC can make
// Stop block indefinitely; consider falling back to Stop after a timeout.
func (s *Server) Stop() {
	s.logger.Info("Stopping gRPC server")
	s.grpc.GracefulStop()
}

// authUnaryInterceptor returns a unary interceptor that requires a valid token
// on every unary RPC except /openspeak.v1.AuthService/Login. The token is read
// with extractToken and checked with tm.ValidateToken. On success it is stored
// in the request context under the "token" key before the handler runs.
//
// Rejections are returned as gRPC status errors with codes.Unauthenticated.
// Plain Go errors would reach clients as codes.Unknown.
//
// Streaming RPCs are not covered by this interceptor.
func authUnaryInterceptor(log *logger.Logger, tm *auth.TokenManager) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		// Login is how a client obtains a token, so it cannot require one.
		if info.FullMethod == "/openspeak.v1.AuthService/Login" {
			return handler(ctx, req)
		}

		token := extractToken(ctx)
		if token == "" {
			log.Warn("Missing token in request")
			return nil, status.Error(codes.Unauthenticated, "unauthorized: missing token")
		}

		if _, err := tm.ValidateToken(token); err != nil {
			log.Warn(fmt.Sprintf("Invalid token: %v", err))
			return nil, status.Errorf(codes.Unauthenticated, "unauthorized: %v", err)
		}

		// NOTE(review): a built-in string key can collide with other packages'
		// context values (staticcheck SA1029). It is kept because other code may
		// read ctx.Value("token") directly; switch to an unexported key type once
		// every reader goes through extractToken.
		ctx = context.WithValue(ctx, "token", token)
		return handler(ctx, req)
	}
}

// extractToken returns the caller's token, or "" if none is present.
//
// It reads the first "authorization" value from the incoming gRPC metadata,
// trims surrounding whitespace, and strips an optional case-insensitive
// "Bearer " scheme prefix, so both "<token>" and "Bearer <token>" are accepted.
// If there is no such metadata value, it falls back to a string stored in the
// context under the "token" key (as authUnaryInterceptor does).
func extractToken(ctx context.Context) string {
	if md, ok := metadata.FromIncomingContext(ctx); ok {
		if values := md.Get("authorization"); len(values) > 0 {
			token := strings.TrimSpace(values[0])
			const scheme = "bearer "
			if len(token) >= len(scheme) && strings.EqualFold(token[:len(scheme)], scheme) {
				token = strings.TrimSpace(token[len(scheme):])
			}
			return token
		}
	}

	if token, ok := ctx.Value("token").(string); ok {
		return token
	}
	return ""
}

// GetTokenManager returns the token manager passed to NewServer.
func (s *Server) GetTokenManager() *auth.TokenManager {
	return s.tokenManager
}

// GetChannelManager returns the channel manager passed to NewServer.
func (s *Server) GetChannelManager() *channel.Manager {
	return s.channelManager
}

// GetPresenceManager returns the presence manager passed to NewServer.
func (s *Server) GetPresenceManager() *presence.Manager {
	return s.presenceManager
}

// GetVoiceRouter returns the voice router passed to NewServer.
func (s *Server) GetVoiceRouter() *voice.Router {
	return s.voiceRouter
}

// GetLogger returns the logger passed to NewServer.
func (s *Server) GetLogger() *logger.Logger {
	return s.logger
}