gRPC interceptor

go-grpc-middleware这个项目对gRPC的interceptor进行了封装,支持多个拦截器的链式组装。最近正好在看gRPC的拦截器和这个grpc-middleware,发现跟我们做的handlerChain类似,其他go-web框架通常也有自身的filter功能实现对其入口出口消息的拦截,这里总结一下实现方式。

gRPC Interceptor

客户端和服务端都可以在初始化时注册拦截器,现在middleware里提供的中间件拦截器包含auth,添加context的request header的ctxtags,支持zap/logrus日志库,支持prometheus监控上报,支持opentracing,客户端支持重试retry,服务端支持校验和恢复等功能。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// ServerInterceptor
s := grpc.NewServer(grpc.StreamInterceptor(grpc_ctxtags.StreamServerInterceptor()),
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
grpc_ctxtags.UnaryServerInterceptor(),
grpc_opentracing.UnaryServerInterceptor(),
grpc_prometheus.UnaryServerInterceptor,
grpc_zap.UnaryServerInterceptor(zapLogger),
grpc_auth.UnaryServerInterceptor(myAuthFunction),
grpc_recovery.UnaryServerInterceptor(),
)),
)
// ClientInterceptor
conn, err := grpc.Dial(address, grpc.WithInsecure(),
grpc.WithStreamInterceptor(grpc_retry.StreamClientInterceptor),
grpc.WithUnaryInterceptor(grpc_retry.UnaryServerInterceptor),
)

server

server侧在注册的handler里留有interceptor的接口,当它不为空时调用interceptor去处理,这时service原有的handler作为参数交给interceptor,供其调用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
func _Greeter_SayHello_Handler(srv interface{}, ctx context.Context, 
dec func(interface{}) error,
interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(HelloRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(GreeterServer).SayHello(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/helloworld.Greeter/SayHello",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(GreeterServer).SayHello(ctx, req.(*HelloRequest))
}
return interceptor(ctx, in, info, handler)
}

interceptor就是一个func类型,其返回值和Handler的返回值相同。这是opentracing实现的interceptor,通常在拦截器中执行handler前后会加入拦截器自己的逻辑。

1
2
type UnaryServerInterceptor func(ctx context.Context, req interface{}, 
info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error)
1
2
3
4
5
6
7
8
9
10
func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor {
o := evaluateOptions(opts)
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (interface{}, error) {
newCtx, serverSpan := newServerSpanFromInbound(ctx, o.tracer, info.FullMethod)
resp, err := handler(newCtx, req)
finishServerSpan(ctx, serverSpan, err)
return resp, err
}
}

client

客户端的调用经过grpc.Invoke,在invoke中实现了interceptor的调用。跟server侧的实现类似,client拦截器参数包含invoke,供拦截器实现方调用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
func (c *userClient) QueryUser(ctx context.Context, in *UserRequest, opts ...grpc.CallOption) (*UserResponse, error) {
out := new(UserResponse)
err := grpc.Invoke(ctx, "/User/QueryUser", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}

func Invoke(ctx context.Context, method string, args, reply interface{},
cc *ClientConn, opts ...CallOption) error {
if cc.dopts.unaryInt != nil {
return cc.dopts.unaryInt(ctx, method, args, reply, cc, invoke, opts...)
}
return invoke(ctx, method, args, reply, cc, opts...)
}

UnaryClientInterceptor也是一个func类型,其返回值和invoke的返回值相同,这是opentracing实现的interceptor,通常在客户端拦截器在执行invoker的前后会加入拦截器自己的逻辑。

1
2
3
type UnaryClientInterceptor func(ctx context.Context, method string, req, 
reply interface{}, cc *ClientConn, invoker UnaryInvoker,
opts ...CallOption) error
1
2
3
4
5
6
7
8
9
10
11
func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor {
o := evaluateOptions(opts)
return func(parentCtx context.Context, method string, req,
reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker,
opts ...grpc.CallOption) error {
newCtx, clientSpan := newClientSpanFromContext(parentCtx, o.tracer, method)
err := invoker(newCtx, method, req, reply, cc, opts...)
finishClientSpan(clientSpan, err)
return err
}
}

chain

因为gRPC提供的接口只支持一个拦截器,但通常希望按一定顺序去执行多个拦截器。于是 go-grpc-middleware这个项目提供了grpc_middleware.ChainUnaryServer方法来实现链式处理。以Server为例,Chain集合了所有拦截器,当数量大于1的时候,从interceptors[0]开始递归,每个转入interceptors[i]的handler都是经过后续拦截器包装的chainHandler。curI作为闭包的参数,跟随递归过程自加,依次调用chainHandler中的interceptor[curI],直到最后一个调用真正的handler。之后返回curlI自减,保证依次反向执行interceptors在handler之后要执行的操作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
func ChainUnaryServer(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
n := len(interceptors)

if n > 1 {
lastI := n - 1
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
var (
chainHandler grpc.UnaryHandler
curI int
)
chainHandler = func(currentCtx context.Context, currentReq interface{}) (interface{}, error) {
if curI == lastI {
return handler(currentCtx, currentReq)
}
curI++
resp, err := interceptors[curI](currentCtx, currentReq, info, chainHandler)
curI--
return resp, err
}
return interceptors[0](ctx, req, info, chainHandler)
}
}

if n == 1 {
return interceptors[0]
}

// n == 0; Dummy interceptor maintained for backward compatibility to avoid returning nil.
return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
return handler(ctx, req)
}
}

service-center chain

service-center里有单独的interceptor,是用来处理routing前的准入控制,chain提供的handler是用来做路由后的处理。下面的代码来自roa.GetRouter().ServeHTTP(w, r)

服务中心的是服务端的处理链,invocation通过NewChain来初始化,将注册的handler加入其中。在每个handler里通过操作context来处理请求和响应。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
inv := chain.NewInvocation(ctx, chain.NewChain(SERVER_CHAIN_NAME, hs))
inv.WithContext(CTX_RESPONSE, w).
WithContext(CTX_REQUEST, r).
WithContext(CTX_MATCH_PATTERN, ph.Path).
WithContext(CTX_MATCH_FUNC, ph.Name)
inv.Next(chain.WithFunc(func(ret chain.Result) {
defer func() {
err := ret.Err
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
}()
if ret.OK {
ph.ServeHTTP(w, r)
}
}))

chain

invocation中的chain跟go-chassis中实现类似,Next方法(syncNext)依次调用注册的Handler。当chain里的handler都处理完成后调用i.Success开始Callback。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
type Chain struct {
name string
handlers []Handler
currentIndex int
}

func (c *Chain) syncNext(i *Invocation) {
defer func() {
itf := recover()
if itf == nil {
return
}
util.LogPanic(itf)
i.Fail(errorsEx.RaiseError(itf))
}()
// handler已遍历完成
if c.currentIndex >= len(c.handlers)-1 {
i.Success()
return
}
c.currentIndex += 1
c.handlers[c.currentIndex].Handle(i)
}

以下是一个metricHandler的实现,它只是调用invocation.Next并且通过WithAsyncFunc实现这个异步函数的注册。其中w r等参数以闭包的形式(地址)与func一起注册到invocation中。其实在i.Next之前的内容是在server端业务逻辑之前执行的内容,其他的Handler比如authHandler就只有前面部分,而没有callback注册。还有一些callback需要原始参数,可以在handle里先计算并以闭包的形式传递。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
type MetricsHandler struct {}

func (h *MetricsHandler) Handle(i *chain.Invocation) {
w, r := i.Context().Value(rest.CTX_RESPONSE).(http.ResponseWriter),
i.Context().Value(rest.CTX_REQUEST).(*http.Request)
i.Next(chain.WithAsyncFunc(func(ret chain.Result) {
start, ok := i.Context().Value(svr.CTX_START_TIMESTAMP).(time.Time)
if !ok {
return
}
svr.ReportRequestCompleted(w, r, start)
util.LogNilOrWarnf(start, "%s %s", r.Method, r.RequestURI)
}))
}

func RegisterHandlers() {
chain.RegisterHandler(rest.SERVER_CHAIN_NAME, &MetricsHandler{})
}

invoation

invocation的处理链数据结构包含Callback,其中Calback以匿名组合的方式包含回调Func和Aysnc是否异步的flag。

1
2
3
4
5
6
7
8
9
type Invocation struct {
Callback
context *util.StringContext
chain Chain
}
type Callback struct {
Func func(r Result)
Async bool
}

invocation通过WithFunc这种函数式参数的方式,每个Handler中通过WithAsyncFunc通过setCallback注册到了invocation的Callback中。在chain的Next中不断地在回调函数i.Func中添加各handler的AsyncFunc。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
func (i *Invocation) Next(opts ...InvocationOption) {
var op InvocationOp
for _, opt := range opts {
op = opt(op)
}
// 注册每个handler的回调函数
i.setCallback(op.Func, op.Async)
i.chain.Next(i)
}

func (i *Invocation) setCallback(f func(r Result), async bool) {
if f == nil {
return
}

if i.Func == nil {
i.Func = f
i.Async = async
return
}
// 回调函数以注册的次序写在Callback的Func中
cb := i.Func
i.Func = func(r Result) {
cb(r)
callback(f, async, r)
}
}

func callback(f func(r Result), async bool, r Result) {
c := Callback{Func: f, Async: async}
c.Invoke(r)
}

callback

当最后一个chain成功后开始i.Success,也就是通过cb.Invoke开始回调。cb中的callback首先是ph.ServeHTTP(w, r),之后是handler注册的aysncFunc或syncFunc,如果是sync的则同步调用,如果异步则通过new一个goroutine来执行。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
func (cb *Callback) Success(args ...interface{}) {
cb.Invoke(Result{
OK: true,
Args: args,
})
}

func (cb *Callback) Invoke(r Result) {
if cb.Async {
go syncInvoke(cb.Func, r)
return
}
syncInvoke(cb.Func, r)
}

func syncInvoke(f func(r Result), r Result) {
defer func() {
if itf := recover(); itf != nil {
util.LogPanic(itf)
}
}()
if f == nil {
util.Logger().Errorf(nil, "Callback function is nil. result: %s,", r)
return
}
f(r)
}

go-chassis chain

go-chassis支持RESTful和RPC两种通信机制,对他们的客户端和服务端的调用实现了通用的处理链,支持配置和扩展。客户端的handlerChain里最后一个是transport,用来实现最后请求的invoke。

go-chassis里没有在业务逻辑之后添加handler的实例,但是跟service-center实现类似可以在chain.Next中添加回调逻辑,将可以在业务之后按照反序进行回调。

client

go-chassis中实现的HandlerChain支持RESTful/RPC的客户端服务端。其中客户端的RPC和RESTful方式相同,从handler包的ChainMap中获取注册的chain

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
func (ri *RestInvoker) ContextDo(ctx context.Context, req *rest.Request,
options ...InvocationOption) (*rest.Response, error) {
inv := invocation.CreateInvocation()
// create invoation with req and options
c, _ := handler.GetChain(common.Consumer, ri.opts.ChainName)
c.Next(inv, func(ir *invocation.InvocationResponse) error {
err = ir.Err
return err
})
return resp, err
}

func (ri *RPCInvoker) Invoke(ctx context.Context, microServiceName, schemaID,
operationID string, arg interface{}, reply interface{},
options ...InvocationOption) error {
i := invocation.CreateInvocation()
// create invoation with arg reply and options
c, _ := handler.GetChain(common.Consumer, ri.opts.ChainName)
c.Next(i, func(ir *invocation.InvocationResponse) error {
err = ir.Err
reply = ir.Result
return err
})
return err
}

server

以下是RESTful版本的server端处理,RPC版本的类似,这个server端的处理只支持在serve之前添加处理模块,而不能在这之后。这里可以对比一下service-center的实现,见下一节。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
handle := func(req *restful.Request, rep *restful.Response) {
c, err := handler.GetChain(common.Provider, r.opts.ChainName)
//todo: use it for hystric
inv := invocation.Invocation{
MicroServiceName: r.microServiceName,
SourceMicroService: req.HeaderParameter(common.HeaderSourceName),
Args: req,
Protocol: common.ProtocolRest,
SchemaID: schemaName,
OperationID: method.Name,
}
bs := NewBaseServer(context.TODO())
bs.req = req
bs.resp = rep
c.Next(&inv, func(ir *invocation.InvocationResponse) error {
if ir.Err != nil {
return ir.Err
}
method.Func.Call([]reflect.Value{schemaValue, reflect.ValueOf(bs)})
if bs.resp.StatusCode() >= http.StatusBadRequest {
return fmt.Errorf("get err from http handle", bs.resp.StatusCode())
}
return nil
})
}

chain

handlerChain是通过文件配置的方式定义的,并在初始化的时候根据处理链类型(provider or consumer)和处理链名称newChain并完成注册。

1
2
3
4
5
6
type Chain struct {
ServiceType string
Name string
Handlers []Handler
HandlerIndex int
}

每个handler实现Handle和Name两个方法,然后在每个Handler末尾继续调用chain.Next,并将callback递归下去,直到最后一个Handler。

1
2
3
4
type Handler interface {
Handle(*Chain, *invocation.Invocation, invocation.ResponseCallBack)
Name() string
}
1
2
3
4
5
6
7
8
9
10
11
12
func (c *Chain) Next(i *invocation.Invocation, f invocation.ResponseCallBack) {
index := c.HandlerIndex
if index >= len(c.Handlers) {
r := &invocation.InvocationResponse{
Err: nil,
}
f(r)
return
}
c.HandlerIndex++
c.Handlers[index].Handle(c, i, f)
}

以下是一个fault-inject的例子,出错后开始回调callback,并将出错handler中的invocatoinResponse通过回调回到invoke中的callback函数,将错误信息error传递到invoke中的err。在Handle过程中,根据需要操作inv的内容,invocation结构体包含了所有server和client通信的相关数据。handler主要可以修改里面的请求Args和响应Reply。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19


func init() { RegisterHandler(FaultHandlerName, FaultHandle) }

func (rl *FaultHandler) Name() string { return "fault-inject" }
func (rl *FaultHandler) Handle(chain *Chain, inv *invocation.Invocation,
cb invocation.ResponseCallBack) {
// fault handler do something...
r := &invocation.InvocationResponse{}
err := faultInject(faultValue, inv)
if err != nil {
r.Err = err
cb(r)
return
}
chain.Next(inv, func(r *invocation.InvocationResponse) error {
return cb(r)
})
}