diff --git a/pkg/client/counter/counter.go b/pkg/client/counter/counter.go index f701d89..f3244b8 100644 --- a/pkg/client/counter/counter.go +++ b/pkg/client/counter/counter.go @@ -16,6 +16,7 @@ package counter import ( "context" + api "github.com/atomix/api/proto/atomix/counter" "github.com/atomix/api/proto/atomix/headers" "github.com/atomix/go-client/pkg/client/primitive" @@ -47,6 +48,9 @@ type Counter interface { // Decrement decrements the counter by the given delta Decrement(ctx context.Context, delta int64) (int64, error) + + // CAS checks the counter value and then updates its current value + CAS(ctx context.Context, expect int64, update int64) (bool, error) } // New creates a new counter for the given partitions @@ -149,6 +153,26 @@ func (c *counter) Decrement(ctx context.Context, delta int64) (int64, error) { return response.(*api.DecrementResponse).NextValue, nil } +func (c *counter) CAS(ctx context.Context, expect int64, update int64) (bool, error) { + response, err := c.instance.DoCommand(ctx, func(ctx context.Context, conn *grpc.ClientConn, header *headers.RequestHeader) (*headers.ResponseHeader, interface{}, error) { + client := api.NewCounterServiceClient(conn) + request := &api.CheckAndSetRequest{ + Header: header, + Expect: expect, + Update: update, + } + response, err := client.CheckAndSet(ctx, request) + if err != nil { + return nil, nil, err + } + return response.Header, response, nil + }) + if err != nil { + return false, err + } + return response.(*api.CheckAndSetResponse).Succeeded, nil +} + func (c *counter) Close(ctx context.Context) error { return c.instance.Close(ctx) } diff --git a/pkg/client/counter/counter_test.go b/pkg/client/counter/counter_test.go index eafea4d..97028b6 100644 --- a/pkg/client/counter/counter_test.go +++ b/pkg/client/counter/counter_test.go @@ -16,10 +16,11 @@ package counter import ( "context" + "testing" + "github.com/atomix/go-client/pkg/client/primitive" "github.com/atomix/go-client/pkg/client/test" "github.com/stretchr/testify/assert" - "testing" ) func TestCounterOperations(t *testing.T) { @@ -69,6 +70,18 @@ func TestCounterOperations(t *testing.T) { assert.NoError(t, err) assert.Equal(t, int64(10), value) + casValue, err := counter.CAS(context.TODO(), 15, 25) + assert.NoError(t, err) + assert.Equal(t, false, casValue) + + casValue, err = counter.CAS(context.TODO(), 10, 20) + assert.NoError(t, err) + assert.Equal(t, true, casValue) + + value, err = counter.Get(context.TODO()) + assert.NoError(t, err) + assert.Equal(t, int64(20), value) + err = counter.Close(context.Background()) assert.NoError(t, err) @@ -80,7 +93,7 @@ func TestCounterOperations(t *testing.T) { value, err = counter1.Get(context.TODO()) assert.NoError(t, err) - assert.Equal(t, int64(10), value) + assert.Equal(t, int64(20), value) err = counter1.Close(context.Background()) assert.NoError(t, err) diff --git a/pkg/client/list/list.go b/pkg/client/list/list.go index 2592353..42e2100 100644 --- a/pkg/client/list/list.go +++ b/pkg/client/list/list.go @@ -18,6 +18,7 @@ import ( "context" "encoding/base64" "errors" + "github.com/atomix/api/proto/atomix/headers" api "github.com/atomix/api/proto/atomix/list" "github.com/atomix/go-client/pkg/client/primitive" @@ -58,6 +59,9 @@ type List interface { // Len gets the length of the list Len(ctx context.Context) (int, error) + // Contains checks whether the list contains a value + Contains(ctx context.Context, value []byte) (bool, error) + // Slice returns a slice of the list from the given start index to the given end index Slice(ctx context.Context, from int, to int) (List, error) @@ -258,6 +262,25 @@ func (l *list) Remove(ctx context.Context, index int) ([]byte, error) { } } +func (l *list) Contains(ctx context.Context, value []byte) (bool, error) { + response, err := l.instance.DoQuery(ctx, func(ctx context.Context, conn *grpc.ClientConn, header *headers.RequestHeader) (*headers.ResponseHeader, interface{}, error) { + client := api.NewListServiceClient(conn) + request := &api.ContainsRequest{ + Header: header, + Value: base64.StdEncoding.EncodeToString(value), + } + response, err := client.Contains(ctx, request) + if err != nil { + return nil, nil, err + } + return response.Header, response, nil + }) + if err != nil { + return false, err + } + return bool(response.(*api.ContainsResponse).Contains), nil +} + func (l *list) Len(ctx context.Context) (int, error) { response, err := l.instance.DoQuery(ctx, func(ctx context.Context, conn *grpc.ClientConn, header *headers.RequestHeader) (*headers.ResponseHeader, interface{}, error) { client := api.NewListServiceClient(conn) diff --git a/pkg/client/list/list_test.go b/pkg/client/list/list_test.go index 41ba7e4..1c7946d 100644 --- a/pkg/client/list/list_test.go +++ b/pkg/client/list/list_test.go @@ -16,10 +16,11 @@ package list import ( "context" + "testing" + "github.com/atomix/go-client/pkg/client/primitive" "github.com/atomix/go-client/pkg/client/test" "github.com/stretchr/testify/assert" - "testing" ) func TestListOperations(t *testing.T) { @@ -53,6 +54,10 @@ func TestListOperations(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "foo", string(value)) + contain, err := list.Contains(context.TODO(), []byte("foo")) + assert.NoError(t, err) + assert.Equal(t, true, contain) + err = list.Append(context.TODO(), []byte("bar")) assert.NoError(t, err) diff --git a/pkg/client/list/slice.go b/pkg/client/list/slice.go index 992e75c..a9b3132 100644 --- a/pkg/client/list/slice.go +++ b/pkg/client/list/slice.go @@ -17,6 +17,7 @@ package list import ( "context" "errors" + "github.com/atomix/go-client/pkg/client/primitive" ) @@ -96,6 +97,11 @@ func (l *slicedList) Len(ctx context.Context) (int, error) { return size, nil } +func (l *slicedList) Contains(ctx context.Context, value []byte) (bool, error) { + contain, err := l.list.Contains(ctx, value) + return contain, err +} + func (l *slicedList) Slice(ctx context.Context, from int, to int) (List, error) { if l.from != nil { from += *l.from