[go: nahoru, domu]

Skip to content

Commit

Permalink
fix(bizerr): support biz status error for streaming mode (#1238)
Browse files (browse the repository at this point in the history)
  • Loading branch information
jayantxie committed Jan 26, 2024
1 parent e63ed49 commit 1167bb5
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 16 deletions.
11 changes: 8 additions & 3 deletions client/stream.go
Original file line number | Diff line number | Diff line change
Expand Up @@ -90,7 +90,7 @@ func (kc *kClient) invokeStreamingEndpoint() (endpoint.Endpoint, error) {
if err != nil {
return
}
clientStream := newStream(st, kc, kc.getStreamingMode(ri), sendEndpoint, recvEndpoint)
clientStream := newStream(st, kc, ri, kc.getStreamingMode(ri), sendEndpoint, recvEndpoint)
resp.(*streaming.Result).Stream = clientStream
return
}, nil
Expand All @@ -107,6 +107,7 @@ func (kc *kClient) getStreamingMode(ri rpcinfo.RPCInfo) serviceinfo.StreamingMod
type stream struct {
stream streaming.Stream
kc *kClient
ri rpcinfo.RPCInfo

streamingMode serviceinfo.StreamingMode
sendEndpoint endpoint.SendEndpoint
Expand All @@ -116,12 +117,13 @@ type stream struct {

var _ streaming.WithDoFinish = (*stream)(nil)

func newStream(s streaming.Stream, kc *kClient, mode serviceinfo.StreamingMode,
sendEP endpoint.SendEndpoint, recvEP endpoint.RecvEndpoint,
func newStream(s streaming.Stream, kc *kClient, ri rpcinfo.RPCInfo,
mode serviceinfo.StreamingMode, sendEP endpoint.SendEndpoint, recvEP endpoint.RecvEndpoint,
) *stream {
return &stream{
stream: s,
kc: kc,
ri: ri,
streamingMode: mode,
sendEndpoint: sendEP,
recvEndpoint: recvEP,
Expand Down Expand Up @@ -161,6 +163,9 @@ func (s *stream) Context() context.Context {
// If an error is returned, stream.DoFinish() will be called to record the end of stream
func (s *stream) RecvMsg(m interface{}) (err error) {
err = s.recvEndpoint(s.stream, m)
if err == nil {
err = s.ri.Invocation().BizStatusErr()
}
if err != nil || s.streamingMode == serviceinfo.StreamingClient {
s.DoFinish(err)
}
Expand Down
28 changes: 15 additions & 13 deletions client/stream_test.go
Original file line number | Diff line number | Diff line change
Expand Up @@ -103,7 +103,7 @@ func TestStreaming(t *testing.T) {
cliInfo.ConnPool = connpool
s, _ := remotecli.NewStream(ctx, mockRPCInfo, new(mocks.MockCliTransHandler), cliInfo)
stream := newStream(
s, kc, serviceinfo.StreamingBidirectional,
s, kc, mockRPCInfo, serviceinfo.StreamingBidirectional,
func(stream streaming.Stream, message interface{}) (err error) {
return stream.SendMsg(message)
},
Expand Down Expand Up @@ -214,6 +214,7 @@ func Test_newStream(t *testing.T) {
s := newStream(
st,
kc,
ri,
serviceinfo.StreamingClient,
func(stream streaming.Stream, message interface{}) (err error) {
return sendErr
Expand Down Expand Up @@ -247,7 +248,7 @@ func Test_stream_Header(t *testing.T) {
},
}

s := newStream(st, &kClient{}, serviceinfo.StreamingBidirectional, nil, nil)
s := newStream(st, &kClient{}, nil, serviceinfo.StreamingBidirectional, nil, nil)
md, err := s.Header()

test.Assert(t, err == nil)
Expand Down Expand Up @@ -278,7 +279,7 @@ func Test_stream_Header(t *testing.T) {
},
}

s := newStream(st, kc, serviceinfo.StreamingBidirectional, nil, nil)
s := newStream(st, kc, ri, serviceinfo.StreamingBidirectional, nil, nil)
md, err := s.Header()

test.Assert(t, err == headerErr)
Expand All @@ -289,7 +290,8 @@ func Test_stream_Header(t *testing.T) {

func Test_stream_RecvMsg(t *testing.T) {
t.Run("no-error", func(t *testing.T) {
s := newStream(&mockStream{}, &kClient{}, serviceinfo.StreamingBidirectional, nil,
mockRPCInfo := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("mock_service", "mock_method"), nil, nil)
s := newStream(&mockStream{}, &kClient{}, mockRPCInfo, serviceinfo.StreamingBidirectional, nil,
func(stream streaming.Stream, message interface{}) (err error) {
return nil
},
Expand All @@ -301,7 +303,7 @@ func Test_stream_RecvMsg(t *testing.T) {
})

t.Run("no-error-client-streaming", func(t *testing.T) {
ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, rpcinfo.NewRPCStats())
ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("mock_service", "mock_method"), nil, rpcinfo.NewRPCStats())
st := &mockStream{
ctx: rpcinfo.NewCtxWithRPCInfo(context.Background(), ri),
}
Expand All @@ -319,7 +321,7 @@ func Test_stream_RecvMsg(t *testing.T) {
},
}

s := newStream(st, kc, serviceinfo.StreamingClient, nil,
s := newStream(st, kc, ri, serviceinfo.StreamingClient, nil,
func(stream streaming.Stream, message interface{}) (err error) {
return nil
},
Expand Down Expand Up @@ -350,7 +352,7 @@ func Test_stream_RecvMsg(t *testing.T) {
},
}

s := newStream(st, kc, serviceinfo.StreamingBidirectional, nil,
s := newStream(st, kc, ri, serviceinfo.StreamingBidirectional, nil,
func(stream streaming.Stream, message interface{}) (err error) {
return recvErr
},
Expand All @@ -364,7 +366,7 @@ func Test_stream_RecvMsg(t *testing.T) {

func Test_stream_SendMsg(t *testing.T) {
t.Run("no-error", func(t *testing.T) {
s := newStream(&mockStream{}, &kClient{}, serviceinfo.StreamingBidirectional,
s := newStream(&mockStream{}, &kClient{}, nil, serviceinfo.StreamingBidirectional,
func(stream streaming.Stream, message interface{}) (err error) {
return nil
},
Expand Down Expand Up @@ -396,7 +398,7 @@ func Test_stream_SendMsg(t *testing.T) {
},
}

s := newStream(st, kc, serviceinfo.StreamingBidirectional,
s := newStream(st, kc, ri, serviceinfo.StreamingBidirectional,
func(stream streaming.Stream, message interface{}) (err error) {
return sendErr
},
Expand All @@ -416,7 +418,7 @@ func Test_stream_Close(t *testing.T) {
called = true
return nil
},
}, &kClient{}, serviceinfo.StreamingBidirectional, nil, nil)
}, &kClient{}, nil, serviceinfo.StreamingBidirectional, nil, nil)

err := s.Close()

Expand All @@ -438,7 +440,7 @@ func Test_stream_DoFinish(t *testing.T) {
TracerCtl: ctl,
},
}
s := newStream(st, kc, serviceinfo.StreamingBidirectional, nil, nil)
s := newStream(st, kc, ri, serviceinfo.StreamingBidirectional, nil, nil)

finishCalled := false
err := errors.New("any err")
Expand All @@ -465,7 +467,7 @@ func Test_stream_DoFinish(t *testing.T) {
TracerCtl: ctl,
},
}
s := newStream(st, kc, serviceinfo.StreamingBidirectional, nil, nil)
s := newStream(st, kc, ri, serviceinfo.StreamingBidirectional, nil, nil)

finishCalled := false
err := errors.New("any err")
Expand All @@ -492,7 +494,7 @@ func Test_stream_DoFinish(t *testing.T) {
TracerCtl: ctl,
},
}
s := newStream(st, kc, serviceinfo.StreamingBidirectional, nil, nil)
s := newStream(st, kc, ri, serviceinfo.StreamingBidirectional, nil, nil)

finishCalled := false
expectedErr := errors.New("error")
Expand Down
4 changes: 4 additions & 0 deletions internal/mocks/serviceinfo.go
Original file line number | Diff line number | Diff line change
Expand Up @@ -50,6 +50,10 @@ func newServiceInfo() *serviceinfo.ServiceInfo {
"mockException": serviceinfo.NewMethodInfo(mockExceptionHandler, NewMockArgs, newMockExceptionResult, false),
"mockError": serviceinfo.NewMethodInfo(mockErrorHandler, NewMockArgs, NewMockResult, false),
"mockOneway": serviceinfo.NewMethodInfo(mockOnewayHandler, NewMockArgs, nil, true),
"mockStreaming": serviceinfo.NewMethodInfo(
nil, nil, nil, false,
serviceinfo.WithStreamingMode(serviceinfo.StreamingBidirectional),
),
}

svcInfo := &serviceinfo.ServiceInfo{
Expand Down

0 comments on commit 1167bb5

Please sign in to comment.