From a999f34b23ee82bc8c77436db29f84beb3de676e Mon Sep 17 00:00:00 2001 From: Djordje Lukic Date: Mon, 15 Jun 2020 23:21:52 +0200 Subject: [PATCH] Use the context from the metadata if it exists --- server/contextserverstream.go | 41 ++++++++++++ server/interceptor.go | 111 +++++++++++++++++++++++++++++++ server/interceptor_test.go | 121 ++++++++++++++++++++++++++++++++++ server/server.go | 109 +----------------------------- 4 files changed, 275 insertions(+), 107 deletions(-) create mode 100644 server/contextserverstream.go create mode 100644 server/interceptor.go create mode 100644 server/interceptor_test.go diff --git a/server/contextserverstream.go b/server/contextserverstream.go new file mode 100644 index 000000000..858ef5f68 --- /dev/null +++ b/server/contextserverstream.go @@ -0,0 +1,41 @@ +package server + +import ( + "context" + + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +// A gRPC server stream will only let you get its context but +// there is no way to set a new (augmented context) to the next +// handler (like we do for a unary request). We need to wrap the grpc.ServerSteam +// to be able to set a new context that will be sent to the next stream interceptor. +type contextServerStream struct { + ss grpc.ServerStream + ctx context.Context +} + +func (css *contextServerStream) SetHeader(md metadata.MD) error { + return css.ss.SetHeader(md) +} + +func (css *contextServerStream) SendHeader(md metadata.MD) error { + return css.ss.SendHeader(md) +} + +func (css *contextServerStream) SetTrailer(md metadata.MD) { + css.ss.SetTrailer(md) +} + +func (css *contextServerStream) Context() context.Context { + return css.ctx +} + +func (css *contextServerStream) SendMsg(m interface{}) error { + return css.ss.SendMsg(m) +} + +func (css *contextServerStream) RecvMsg(m interface{}) error { + return css.ss.RecvMsg(m) +} diff --git a/server/interceptor.go b/server/interceptor.go new file mode 100644 index 000000000..d4baf4052 --- /dev/null +++ b/server/interceptor.go @@ -0,0 +1,111 @@ +package server + +import ( + "context" + "errors" + "strings" + + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + "github.com/docker/api/client" + "github.com/docker/api/config" + apicontext "github.com/docker/api/context" + "github.com/docker/api/context/store" + "github.com/docker/api/server/proxy" +) + +// key is the key where the current docker context is stored in the metadata +// of a gRPC request +const key = "context_key" + +// unaryServerInterceptor configures the context and sends it to the next handler +func unaryServerInterceptor(clictx context.Context) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + currentContext, err := getIncomingContext(ctx) + if err != nil { + currentContext, err = getConfigContext(clictx) + if err != nil { + return nil, err + } + } + configuredCtx, err := configureContext(clictx, currentContext, info.FullMethod) + if err != nil { + return nil, err + } + + return handler(configuredCtx, req) + } +} + +// streamServerInterceptor configures the context and sends it to the next handler +func streamServerInterceptor(clictx context.Context) grpc.StreamServerInterceptor { + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + currentContext, err := getIncomingContext(ss.Context()) + if err != nil { + currentContext, err = getConfigContext(clictx) + if err != nil { + return err + } + } + ctx, err := configureContext(clictx, currentContext, info.FullMethod) + if err != nil { + return err + } + + return handler(srv, &contextServerStream{ + ss: ss, + ctx: ctx, + }) + } +} + +// Returns the current context from the configuration file +func getConfigContext(ctx context.Context) (string, error) { + configDir := config.Dir(ctx) + configFile, err := config.LoadFile(configDir) + if err != nil { + return "", err + } + return configFile.CurrentContext, nil +} + +// Returns the context set by the caller if any, error otherwise +func getIncomingContext(ctx context.Context) (string, error) { + if md, ok := metadata.FromIncomingContext(ctx); ok { + if key, ok := md[key]; ok { + return key[0], nil + } + } + + return "", errors.New("not found") +} + +// configureContext populates the request context with objects the client +// needs: the context store and the api client +func configureContext(ctx context.Context, currentContext string, method string) (context.Context, error) { + configDir := config.Dir(ctx) + + ctx = apicontext.WithCurrentContext(ctx, currentContext) + + // The contexts service doesn't need the client + if !strings.Contains(method, "/com.docker.api.protos.context.v1.Contexts") { + c, err := client.New(ctx) + if err != nil { + return nil, err + } + + ctx, err = proxy.WithClient(ctx, c) + if err != nil { + return nil, err + } + } + + s, err := store.New(store.WithRoot(configDir)) + if err != nil { + return nil, err + } + ctx = store.WithContextStore(ctx, s) + + return ctx, nil +} diff --git a/server/interceptor_test.go b/server/interceptor_test.go new file mode 100644 index 000000000..c137df357 --- /dev/null +++ b/server/interceptor_test.go @@ -0,0 +1,121 @@ +package server + +import ( + "context" + "io/ioutil" + "os" + "path" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + "github.com/docker/api/config" + apicontext "github.com/docker/api/context" +) + +type interceptorSuite struct { + suite.Suite + dir string + ctx context.Context +} + +func (is *interceptorSuite) BeforeTest(suiteName, testName string) { + dir, err := ioutil.TempDir("", "example") + require.Nil(is.T(), err) + + ctx := context.Background() + ctx = config.WithDir(ctx, dir) + err = ioutil.WriteFile(path.Join(dir, "config.json"), []byte(`{"currentContext": "default"}`), 0644) + require.Nil(is.T(), err) + + is.dir = dir + is.ctx = ctx +} + +func (is *interceptorSuite) AfterTest(suiteName, tesName string) { + err := os.RemoveAll(is.dir) + require.Nil(is.T(), err) +} + +func (is *interceptorSuite) TestUnaryGetCurrentContext() { + interceptor := unaryServerInterceptor(is.ctx) + + currentContext := is.callUnary(context.Background(), interceptor) + + assert.Equal(is.T(), "default", currentContext) +} + +func (is *interceptorSuite) TestUnaryContextFromMetadata() { + contextName := "test" + + interceptor := unaryServerInterceptor(is.ctx) + reqCtx := context.Background() + reqCtx = metadata.NewIncomingContext(reqCtx, metadata.MD{ + (key): []string{contextName}, + }) + + currentContext := is.callUnary(reqCtx, interceptor) + + assert.Equal(is.T(), contextName, currentContext) +} + +func (is *interceptorSuite) TestStreamGetCurrentContext() { + interceptor := streamServerInterceptor(is.ctx) + + currentContext := is.callStream(context.Background(), interceptor) + + assert.Equal(is.T(), "default", currentContext) +} + +func (is *interceptorSuite) TestStreamContextFromMetadata() { + contextName := "test" + + interceptor := streamServerInterceptor(is.ctx) + reqCtx := context.Background() + reqCtx = metadata.NewIncomingContext(reqCtx, metadata.MD{ + (key): []string{contextName}, + }) + + currentContext := is.callStream(reqCtx, interceptor) + + assert.Equal(is.T(), contextName, currentContext) +} + +func (is *interceptorSuite) callStream(ctx context.Context, interceptor grpc.StreamServerInterceptor) string { + currentContext := "" + err := interceptor(nil, &contextServerStream{ + ctx: ctx, + }, &grpc.StreamServerInfo{ + FullMethod: "/com.docker.api.protos.context.v1.Contexts/test", + }, func(srv interface{}, stream grpc.ServerStream) error { + currentContext = apicontext.CurrentContext(stream.Context()) + return nil + }) + + require.Nil(is.T(), err) + + return currentContext +} + +func (is *interceptorSuite) callUnary(ctx context.Context, interceptor grpc.UnaryServerInterceptor) string { + currentContext := "" + resp, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{ + FullMethod: "/com.docker.api.protos.context.v1.Contexts/test", + }, func(ctx context.Context, req interface{}) (interface{}, error) { + currentContext = apicontext.CurrentContext(ctx) + return nil, nil + }) + + require.Nil(is.T(), err) + require.Nil(is.T(), resp) + + return currentContext +} + +func TestInterceptor(t *testing.T) { + suite.Run(t, new(interceptorSuite)) +} diff --git a/server/server.go b/server/server.go index 3c338db86..4c7d1fc8a 100644 --- a/server/server.go +++ b/server/server.go @@ -35,13 +35,6 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/health" "google.golang.org/grpc/health/grpc_health_v1" - "google.golang.org/grpc/metadata" - - "github.com/docker/api/client" - "github.com/docker/api/config" - apicontext "github.com/docker/api/context" - "github.com/docker/api/context/store" - "github.com/docker/api/server/proxy" ) // New returns a new GRPC server. @@ -55,109 +48,11 @@ func New(ctx context.Context) *grpc.Server { return s } -//CreateListener creates a listener either on tcp://, or local listener, supporting unix:// for unix socket or npipe:// for named pipes on windows +// CreateListener creates a listener either on tcp://, or local listener, +// supporting unix:// for unix socket or npipe:// for named pipes on windows func CreateListener(address string) (net.Listener, error) { if strings.HasPrefix(address, "tcp://") { return net.Listen("tcp", strings.TrimPrefix(address, "tcp://")) } return createLocalListener(address) } - -// unaryServerInterceptor configures the context and sends it to the next handler -func unaryServerInterceptor(clictx context.Context) func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - configuredCtx, err := configureContext(clictx, info.FullMethod) - if err != nil { - return nil, err - } - - return handler(configuredCtx, req) - } -} - -// streamServerInterceptor configures the context and sends it to the next handler -func streamServerInterceptor(clictx context.Context) func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { - return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { - ctx, err := configureContext(clictx, info.FullMethod) - if err != nil { - return err - } - - return handler(srv, newServerStream(ctx, ss)) - } -} - -// configureContext populates the request context with objects the client -// needs: the context store and the api client -func configureContext(ctx context.Context, method string) (context.Context, error) { - configDir := config.Dir(ctx) - configFile, err := config.LoadFile(configDir) - if err != nil { - return nil, err - } - - if configFile.CurrentContext != "" { - ctx = apicontext.WithCurrentContext(ctx, configFile.CurrentContext) - } - - // The contexts service doesn't need the client - if !strings.Contains(method, "/com.docker.api.protos.context.v1.Contexts") { - c, err := client.New(ctx) - if err != nil { - return nil, err - } - - ctx, err = proxy.WithClient(ctx, c) - if err != nil { - return nil, err - } - } - - s, err := store.New(store.WithRoot(configDir)) - if err != nil { - return nil, err - } - ctx = store.WithContextStore(ctx, s) - - return ctx, nil -} - -// A gRPC server stream will only let you get its context but -// there is no way to set a new (augmented context) to the next -// handler (like we do for a unary request). We need to wrap the grpc.ServerSteam -// to be able to set a new context that will be sent to the next stream interceptor. -type contextServerStream struct { - s grpc.ServerStream - ctx context.Context -} - -func newServerStream(ctx context.Context, s grpc.ServerStream) grpc.ServerStream { - return &contextServerStream{ - s: s, - ctx: ctx, - } -} - -func (css *contextServerStream) SetHeader(md metadata.MD) error { - return css.s.SetHeader(md) -} - -func (css *contextServerStream) SendHeader(md metadata.MD) error { - return css.s.SendHeader(md) -} - -func (css *contextServerStream) SetTrailer(md metadata.MD) { - css.s.SetTrailer(md) -} - -func (css *contextServerStream) Context() context.Context { - return css.ctx -} - -func (css *contextServerStream) SendMsg(m interface{}) error { - return css.s.SendMsg(m) -} - -func (css *contextServerStream) RecvMsg(m interface{}) error { - return css.s.RecvMsg(m) -}