From 1313b6b879fb60157ca4ba7f4a92c1c3e9bf59ae Mon Sep 17 00:00:00 2001
From: Galen Warren
Date: Sun, 23 Oct 2022 17:45:18 -0400
Subject: [PATCH] feat(firestore): add GetCommitTime TransactionOption

Add a GetCommitTime TransactionOption. GetCommitTime(&t) makes
RunTransaction store the server-reported commit time in t once the
transaction commits successfully; t is not written if the transaction
fails. A nil destination is accepted and ignored.

To support this, TransactionOption gains an unexported
handleCommitResponse method, which RunTransaction calls with the
CommitResponse after a successful commit. The existing options
(MaxAttempts, ReadOnly) implement it as a no-op.
---
 firestore/transaction.go      | 37 ++++++++++++++++++++++++++++++++++---
 firestore/transaction_test.go | 11 ++++++++++-
 2 files changed, 44 insertions(+), 4 deletions(-)

diff --git a/firestore/transaction.go b/firestore/transaction.go
index 8f766b621afc..10f1bb33217e 100644
--- a/firestore/transaction.go
+++ b/firestore/transaction.go
@@ -17,6 +17,7 @@ package firestore
 import (
 	"context"
 	"errors"
+	"time"
 
 	"cloud.google.com/go/internal/trace"
 	gax "github.com/googleapis/gax-go/v2"
@@ -40,6 +41,7 @@ type Transaction struct {
 // A TransactionOption is an option passed to Client.Transaction.
 type TransactionOption interface {
 	config(t *Transaction)
+	handleCommitResponse(r *pb.CommitResponse)
 }
 
 // MaxAttempts is a TransactionOption that configures the maximum number of times to
@@ -48,7 +50,8 @@ func MaxAttempts(n int) maxAttempts { return maxAttempts(n) }
 
 type maxAttempts int
 
-func (m maxAttempts) config(t *Transaction) { t.maxAttempts = int(m) }
+func (m maxAttempts) config(t *Transaction)                     { t.maxAttempts = int(m) }
+func (m maxAttempts) handleCommitResponse(r *pb.CommitResponse) {}
 
 // DefaultTransactionMaxAttempts is the default number of times to attempt a transaction.
 const DefaultTransactionMaxAttempts = 5
@@ -59,7 +62,27 @@ var ReadOnly = ro{}
 
 type ro struct{}
 
-func (ro) config(t *Transaction) { t.readOnly = true }
+func (ro) config(t *Transaction)                     { t.readOnly = true }
+func (ro) handleCommitResponse(r *pb.CommitResponse) {}
+
+// GetCommitTime is a TransactionOption that makes RunTransaction store the
+// commit time of the transaction in *t once it commits successfully. *t is
+// not modified if the transaction fails. A nil t is allowed, in which case
+// the commit time is discarded.
+func GetCommitTime(t *time.Time) commitTime {
+	return commitTime{dst: t}
+}
+
+type commitTime struct {
+	dst *time.Time // where to store the commit time; may be nil
+}
+
+func (c commitTime) config(t *Transaction) {}
+func (c commitTime) handleCommitResponse(r *pb.CommitResponse) {
+	if c.dst != nil {
+		*c.dst = r.GetCommitTime().AsTime()
+	}
+}
 
 var (
 	// Defined here for testing.
@@ -114,6 +137,7 @@ func (c *Client) RunTransaction(ctx context.Context, f func(context.Context, *Tr
 		}
 	}
 	var backoff gax.Backoff
+	var commitResponse *pb.CommitResponse
 	// TODO(jba): use other than the standard backoff parameters?
 	// TODO(jba): get backoff time from gRPC trailer metadata? See
 	// extractRetryDelay in https://code.googlesource.com/gocloud/+/master/spanner/retry.go.
@@ -141,13 +165,20 @@ func (c *Client) RunTransaction(ctx context.Context, f func(context.Context, *Tr
 			return err
 		}
 		t.ctx = trace.StartSpan(t.ctx, "cloud.google.com/go/firestore.Client.Commit")
-		_, err = t.c.c.Commit(t.ctx, &pb.CommitRequest{
+		commitResponse, err = t.c.c.Commit(t.ctx, &pb.CommitRequest{
 			Database:    t.c.path(),
 			Writes:      t.writes,
 			Transaction: t.id,
 		})
 		trace.EndSpan(t.ctx, err)
 
+		// On success, let each option inspect the commit response.
+		if err == nil {
+			for _, opt := range opts {
+				opt.handleCommitResponse(commitResponse)
+			}
+		}
+
 		// If a read-write transaction returns Aborted, retry.
 		// On success or other failures, return here.
 		if t.readOnly || status.Code(err) != codes.Aborted {
diff --git a/firestore/transaction_test.go b/firestore/transaction_test.go
index 2aeb5794ba65..c0e27e66ab48 100644
--- a/firestore/transaction_test.go
+++ b/firestore/transaction_test.go
@@ -84,6 +84,7 @@ func TestRunTransaction(t *testing.T) {
 		},
 		&pb.CommitResponse{CommitTime: aTimestamp3},
 	)
+	var gotCommitTime time.Time
 	err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
 		docref := c.Collection("C").Doc("a")
 		doc, err := tx.Get(docref)
@@ -95,11 +96,19 @@ func TestRunTransaction(t *testing.T) {
 			return err
 		}
 		return tx.Update(docref, []Update{{Path: "count", Value: count.(int64) + 1}})
-	})
+	}, GetCommitTime(&gotCommitTime))
 	if err != nil {
 		t.Fatal(err)
 	}
 
+	// GetCommitTime must report the time of the successful commit.
+	if want := aTimestamp3.AsTime(); !gotCommitTime.Equal(want) {
+		t.Fatalf("commit time: got %v, want %v", gotCommitTime, want)
+	}
+
+	// A nil destination is allowed and must not panic.
+	GetCommitTime(nil).handleCommitResponse(&pb.CommitResponse{CommitTime: aTimestamp3})
+
 	// Query
 	srv.reset()
 	srv.addRPC(beginReq, beginRes)