[go: nahoru, domu]

Skip to content

Commit

Permalink
Allow optional KV store overwrites.
Browse files Browse the repository at this point in the history
Cleanups: use std::string_view, migrate away from TF_ASSERT/EXPECT.

PiperOrigin-RevId: 639131697
  • Loading branch information
tensorflower-gardener committed May 31, 2024
1 parent 0445480 commit 55aaac6
Show file tree
Hide file tree
Showing 15 changed files with 316 additions and 227 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ class MockCoordinationServiceAgent : public CoordinationServiceAgent {
(override));
// Mock of the two-argument InsertKeyValue overload.
MOCK_METHOD(Status, InsertKeyValue,
(std::string_view key, std::string_view value), (override));
// Mock of the InsertKeyValue overload that additionally takes
// `allow_overwrite`.
MOCK_METHOD(Status, InsertKeyValue,
(std::string_view key, std::string_view value,
bool allow_overwrite),
(override));
// Mock of DeleteKeyValue.
MOCK_METHOD(Status, DeleteKeyValue, (std::string_view key), (override));
// Mock of UpdateKeyValue.
MOCK_METHOD(Status, UpdateKeyValue,
(std::string_view key, std::string_view value), (override));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ message KeyValueEntry {
// Request and response messages for inserting configuration key-value data.
// Request to insert a single configuration key-value entry.
message InsertKeyValueRequest {
// The key-value entry to insert.
KeyValueEntry kv = 1;
// Whether an existing value stored under the same key may be replaced.
// NOTE(review): the service-side handling is not visible here; the client
// tests expect an AlreadyExists error for a duplicate key when this is false
// and a replaced value when it is true - confirm against the service.
bool allow_overwrite = 2;
}

message InsertKeyValueResponse {}
Expand Down
13 changes: 5 additions & 8 deletions third_party/xla/xla/pjrt/distributed/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -70,23 +70,17 @@ cc_library(
deps = [
":key_value_store_interface",
":util",
"//xla:statusor",
"//xla:types",
"//xla:util",
"//xla/tsl/distributed_runtime/coordination:coordination_client",
"//xla/tsl/distributed_runtime/coordination:coordination_service_agent",
"//xla/tsl/distributed_runtime/coordination:coordination_service_error_util",
"//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_client",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:random",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/protobuf:coordination_config_proto_cc",
"@local_tsl//tsl/protobuf:coordination_service_proto_cc",
] + tsl_grpc_cc_dependencies(),
Expand Down Expand Up @@ -150,13 +144,16 @@ xla_cc_test(
"//xla:protobuf_util",
"//xla:status_macros",
"//xla/tsl/distributed_runtime/coordination:coordination_service_agent",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:test",
"@local_tsl//tsl/platform:test_main",
] + tsl_grpc_cc_dependencies(),
Expand Down
19 changes: 12 additions & 7 deletions third_party/xla/xla/pjrt/distributed/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,19 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "grpcpp/channel.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/tsl/distributed_runtime/coordination/coordination_client.h"
#include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h"
#include "xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h"
#include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
#include "tsl/protobuf/coordination_config.pb.h"
#include "tsl/protobuf/coordination_service.pb.h"

Expand All @@ -59,6 +60,8 @@ class DistributedRuntimeCoordinationServiceClient
KeyValueDirGet(std::string_view key) override;
absl::Status KeyValueSet(std::string_view key,
std::string_view value) override;
absl::Status KeyValueSet(std::string_view key, std::string_view value,
bool allow_overwrite) override;
absl::Status KeyValueDelete(std::string_view key) override;
absl::Status WaitAtBarrier(
std::string barrier_id, absl::Duration timeout,
Expand Down Expand Up @@ -150,10 +153,7 @@ DistributedRuntimeCoordinationServiceClient::BlockingKeyValueGet(
absl::StatusOr<std::vector<std::pair<std::string, std::string>>>
DistributedRuntimeCoordinationServiceClient::KeyValueDirGet(
std::string_view key) {
// TODO(hanyangtay): Migrate to string_view for both client and coordination
// agent APIs.
TF_ASSIGN_OR_RETURN(const auto results,
coord_agent_->GetKeyValueDir(std::string(key)));
TF_ASSIGN_OR_RETURN(const auto results, coord_agent_->GetKeyValueDir(key));

std::vector<std::pair<std::string, std::string>> kvs;
kvs.reserve(results.size());
Expand All @@ -173,7 +173,12 @@ absl::Status DistributedRuntimeCoordinationServiceClient::KeyValueDelete(

// Stores `value` under `key` without permitting overwrites.
//
// Forwards to the three-argument overload with allow_overwrite=false so the
// default lives in exactly one place and there is a single return path.
absl::Status DistributedRuntimeCoordinationServiceClient::KeyValueSet(
    std::string_view key, std::string_view value) {
  return KeyValueSet(key, value, /*allow_overwrite=*/false);
}

// Inserts `value` under `key` through the coordination agent, passing
// `allow_overwrite` along unchanged, and returns the agent's status.
absl::Status DistributedRuntimeCoordinationServiceClient::KeyValueSet(
    std::string_view key, std::string_view value, bool allow_overwrite) {
  absl::Status insert_status =
      coord_agent_->InsertKeyValue(key, value, allow_overwrite);
  return insert_status;
}

absl::Status DistributedRuntimeCoordinationServiceClient::WaitAtBarrier(
Expand Down
7 changes: 5 additions & 2 deletions third_party/xla/xla/pjrt/distributed/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "grpcpp/channel.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/statusor.h"
#include "xla/types.h"
#include "tsl/platform/env.h"

namespace tsl {
Expand Down Expand Up @@ -133,6 +134,8 @@ class DistributedRuntimeClient {

virtual absl::Status KeyValueSet(std::string_view key,
std::string_view value) = 0;
virtual absl::Status KeyValueSet(std::string_view key, std::string_view value,
bool allow_overwrite) = 0;

// Delete the key-value. If the key is a directory, recursively clean
// up all key-values under the directory.
Expand Down
43 changes: 40 additions & 3 deletions third_party/xla/xla/pjrt/distributed/client_server_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,31 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <vector>

#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/barrier.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "grpcpp/grpcpp.h"
#include "absl/types/span.h"
#include "grpcpp/channel.h"
#include "grpcpp/create_channel.h"
#include "grpcpp/security/credentials.h"
#include "grpcpp/security/server_credentials.h"
#include "grpcpp/server.h"
#include "grpcpp/server_builder.h"
#include "grpcpp/support/channel_arguments.h"
#include "xla/pjrt/distributed/client.h"
#include "xla/pjrt/distributed/protocol.pb.h"
#include "xla/pjrt/distributed/service.h"
Expand All @@ -39,7 +48,9 @@ limitations under the License.
#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"
#include "tsl/platform/threadpool.h"

namespace xla {
namespace {
Expand Down Expand Up @@ -407,8 +418,8 @@ TEST_F(ClientServerTest, ClientsTerminateShutdownIfAnyClientGoesAway) {
// 1. Internal: node turns into ERROR state during the shutdown call.
// 2. Failed Precondition: node is already in ERROR state before the
// shutdown call (note: agent will still stop sending heartbeats).
EXPECT_TRUE(tsl::errors::IsInternal(statuses[i]) ||
tsl::errors::IsFailedPrecondition(statuses[i]));
EXPECT_TRUE(absl::IsInternal(statuses[i]) ||
absl::IsFailedPrecondition(statuses[i]));
}
}

Expand Down Expand Up @@ -812,6 +823,32 @@ TEST_F(ClientServerTest, KeyValueDirGet) {
Pair("test_dir/3", "3")));
}

// Setting a key that already exists (without allow_overwrite) must fail with
// AlreadyExists and leave the originally stored value untouched.
TEST_F(ClientServerTest, KeyValueSet_Duplicate_Fails) {
  StartService(/*num_nodes=*/1);
  auto client = GetClient(/*node_id=*/0);
  TF_ASSERT_OK(client->Connect());
  TF_ASSERT_OK(client->KeyValueSet("test_key", "original_value"));

  // Second insert under the same key is expected to be rejected.
  const absl::Status duplicate_status =
      client->KeyValueSet("test_key", "never_added");
  EXPECT_TRUE(absl::IsAlreadyExists(duplicate_status));

  // The stored value must still be the first one.
  auto stored =
      client->BlockingKeyValueGet("test_key", absl::Milliseconds(100));
  TF_ASSERT_OK(stored.status());
  EXPECT_EQ(stored.value(), "original_value");
}

// Setting an existing key with allow_overwrite=true must succeed and replace
// the stored value.
TEST_F(ClientServerTest, KeyValueSet_Duplicate_Overwrites) {
  StartService(/*num_nodes=*/1);
  auto client = GetClient(/*node_id=*/0);
  TF_ASSERT_OK(client->Connect());
  TF_ASSERT_OK(client->KeyValueSet("test_key", "original_value"));

  // Second insert under the same key, this time explicitly allowing overwrite.
  const absl::Status overwrite_status = client->KeyValueSet(
      "test_key", "overwritten_value", /*allow_overwrite=*/true);
  TF_EXPECT_OK(overwrite_status);

  // The stored value must now be the replacement.
  auto stored =
      client->BlockingKeyValueGet("test_key", absl::Milliseconds(100));
  TF_ASSERT_OK(stored.status());
  EXPECT_EQ(stored.value(), "overwritten_value");
}

TEST_F(ClientServerTest, KeyValueDelete) {
StartService(/*num_nodes=*/1);
auto client = GetClient(/*node_id=*/0);
Expand Down
20 changes: 11 additions & 9 deletions third_party/xla/xla/python/xla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -660,32 +660,34 @@ NB_MODULE(xla_extension, m_nb) {
.def(
"key_value_set",
[](DistributedRuntimeClient& client, std::string_view key,
std::string_view value) {
std::string_view value, bool allow_overwrite) {
nb::gil_scoped_release gil_release;
xla::ThrowIfError(client.KeyValueSet(key, value));
xla::ThrowIfError(client.KeyValueSet(key, value, allow_overwrite));
},
nb::arg("key"), nb::arg("value"))
nb::arg("key"), nb::arg("value"), nb::arg("allow_overwrite") = false)
.def(
"key_value_set",
[](DistributedRuntimeClient& client, std::string_view key,
nb::bytes value) {
nb::bytes value, bool allow_overwrite) {
nb::gil_scoped_release gil_release;
xla::ThrowIfError(client.KeyValueSet(
key, std::string_view(value.c_str(), value.size())));
key, std::string_view(value.c_str(), value.size()),
allow_overwrite));
},
nb::arg("key"), nb::arg("value"))
nb::arg("key"), nb::arg("value"), nb::arg("allow_overwrite") = false)
// The key must be a string, but the value must a
// Python bytes object.
// Use `key_value_set_bytes()` and `blocking_key_value_get_bytes()`.
.def(
"key_value_set_bytes",
[](DistributedRuntimeClient& client, std::string_view key,
nb::bytes value) {
nb::bytes value, bool allow_overwrite) {
nb::gil_scoped_release gil_release;
xla::ThrowIfError(client.KeyValueSet(
key, std::string_view(value.c_str(), value.size())));
key, std::string_view(value.c_str(), value.size()),
allow_overwrite));
},
nb::arg("key"), nb::arg("value"))
nb::arg("key"), nb::arg("value"), nb::arg("allow_overwrite") = false)
// Assumes that all values in the directory are Python strings.
.def(
"key_value_dir_get",
Expand Down
6 changes: 4 additions & 2 deletions third_party/xla/xla/python/xla_extension/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -775,8 +775,10 @@ class DistributedRuntimeClient:
) -> _Status: ...
def key_value_dir_get(self, key: str) -> _Status: ...
def key_value_dir_get_bytes(self, key: str) -> _Status: ...
# Key-value setters. `allow_overwrite` defaults to False; each method is
# declared exactly once so the stub has no redefinitions.
def key_value_set(self, key: str, value: str,
                  allow_overwrite: bool = False) -> _Status: ...
def key_value_set_bytes(self, key: str, value: bytes,
                        allow_overwrite: bool = False) -> _Status: ...
def key_value_delete(self, key: str) -> _Status: ...
def wait_at_barrier(
self, barrier_id: str, timeout_in_ms: int, process_ids: Optional[List[int]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ tsl_cc_test(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:random",
Expand Down
Loading

0 comments on commit 55aaac6

Please sign in to comment.