[go: nahoru, domu]

blob: 3789ba8be777c3e43b5c7b5b86c62a7977be2932 [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "mojo/public/cpp/platform/platform_channel_server.h"
#include <tuple>
#include <utility>
#include "base/containers/span.h"
#include "base/files/scoped_temp_dir.h"
#include "base/functional/callback.h"
#include "base/run_loop.h"
#include "base/task/single_thread_task_runner.h"
#include "base/test/bind.h"
#include "base/test/task_environment.h"
#include "build/build_config.h"
#include "mojo/core/channel.h"
#include "mojo/public/cpp/platform/named_platform_channel.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
namespace mojo {
namespace {
// Move-only test helper which wraps a functor in a callback and runs that
// callback from its destructor. The tests below bind an instance into another
// callback so they can observe when that callback gets destroyed.
class RunOnDestruction {
 public:
  // Wraps `functor` so that it can be run later, at destruction time.
  template <typename Functor>
  explicit RunOnDestruction(Functor functor)
      : callback_(base::BindLambdaForTesting(functor)) {}

  RunOnDestruction(RunOnDestruction&&) = default;

  ~RunOnDestruction() {
    // Nothing to do when no callback is held (presumably the state of an
    // instance which has been moved from).
    if (!callback_) {
      return;
    }
    std::move(callback_).Run();
  }

 private:
  base::OnceClosure callback_;
};
class TestChannel : public core::Channel::Delegate {
public:
explicit TestChannel(PlatformChannelEndpoint endpoint)
: io_task_runner_(base::SingleThreadTaskRunner::GetCurrentDefault()),
channel_(
core::Channel::Create(this,
core::ConnectionParams{std::move(endpoint)},
core::Channel::HandlePolicy::kRejectHandles,
io_task_runner_)) {
channel_->Start();
}
~TestChannel() override { CHECK(stopped_); }
void Stop() {
CHECK(!stopped_);
stopped_ = true;
// We pump the IO task queue after ShutDown() to ensure completion, as
// Channel implementions post a cleanup task there.
base::RunLoop shutdown_flush;
channel_->ShutDown();
io_task_runner_->PostTask(FROM_HERE, shutdown_flush.QuitClosure());
shutdown_flush.Run();
}
void SendMessage(const std::string& message) {
auto data = base::make_span(
reinterpret_cast<const uint8_t*>(message.data()), message.size());
channel_->Write(core::Channel::Message::CreateIpczMessage(data, {}));
}
std::string WaitForSingleMessage() {
wait_for_message_.Run();
CHECK(received_message_);
return *received_message_;
}
// core::Channel::Delegate:
bool IsIpczTransport() const override {
// We use Channel in ipcz mode because it's simpler. Doesn't matter if
// MojoIpcz is actually enabled.
return true;
}
void OnChannelMessage(const void* payload,
size_t payload_size,
std::vector<PlatformHandle> handles) override {
received_message_ =
std::string(static_cast<const char*>(payload), payload_size);
std::move(quit_).Run();
}
void OnChannelError(core::Channel::Error error) override {}
private:
const scoped_refptr<base::SingleThreadTaskRunner> io_task_runner_;
const scoped_refptr<core::Channel> channel_;
base::RunLoop wait_for_message_;
base::OnceClosure quit_{wait_for_message_.QuitClosure()};
absl::optional<std::string> received_message_;
bool stopped_ = false;
};
class PlatformChannelServerTest : public testing::Test {
public:
PlatformChannelServerTest() { CHECK(temp_dir_.CreateUniqueTempDir()); }
~PlatformChannelServerTest() override = default;
using NamedChannelDetails = std::tuple<PlatformChannelServerEndpoint,
NamedPlatformChannel::ServerName>;
NamedChannelDetails CreateRandomChannel() {
NamedPlatformChannel::Options options;
#if BUILDFLAG(IS_POSIX)
options.socket_dir = temp_dir_.GetPath();
#endif
NamedPlatformChannel channel(options);
return {channel.TakeServerEndpoint(), channel.GetServerName()};
}
PlatformChannelServer& server() { return server_; }
void VerifyEndToEndConnection(PlatformChannelEndpoint a,
PlatformChannelEndpoint b) {
TestChannel channel_a(std::move(a));
TestChannel channel_b(std::move(b));
const std::string kMessage1 = "Hello, world?";
const std::string kMessage2 = "Oh, hi world.";
channel_a.SendMessage(kMessage1);
channel_b.SendMessage(kMessage2);
EXPECT_EQ(kMessage2, channel_a.WaitForSingleMessage());
EXPECT_EQ(kMessage1, channel_b.WaitForSingleMessage());
channel_a.Stop();
channel_b.Stop();
}
private:
base::ScopedTempDir temp_dir_;
base::test::TaskEnvironment task_environment_{
base::test::TaskEnvironment::MainThreadType::IO};
PlatformChannelServer server_;
};
TEST_F(PlatformChannelServerTest, ConnectAfterListen) {
  // The server starts listening first and the client connects afterwards. The
  // server is expected to hand a working endpoint to the callback supplied to
  // Listen().
  auto [server_endpoint, name] = CreateRandomChannel();

  base::RunLoop run_loop;
  PlatformChannelEndpoint endpoint_a;
  auto on_connection =
      base::BindLambdaForTesting([&](PlatformChannelEndpoint accepted) {
        endpoint_a = std::move(accepted);
        run_loop.Quit();
      });
  EXPECT_TRUE(
      server().Listen(std::move(server_endpoint), std::move(on_connection)));

  auto endpoint_b = NamedPlatformChannel::ConnectToServer(name);
  EXPECT_TRUE(endpoint_b.is_valid());

  // Wait for the server to deliver its side of the connection.
  run_loop.Run();
  VerifyEndToEndConnection(std::move(endpoint_a), std::move(endpoint_b));
}
TEST_F(PlatformChannelServerTest, ConnectBeforeListen) {
  // Here the client connects *before* the server calls Listen(). The server is
  // still expected to hand a working endpoint to the callback supplied to
  // Listen().
  auto [server_endpoint, name] = CreateRandomChannel();

  base::RunLoop run_loop;
  auto endpoint_a = NamedPlatformChannel::ConnectToServer(name);
  EXPECT_TRUE(endpoint_a.is_valid());

  PlatformChannelEndpoint endpoint_b;
  auto on_connection =
      base::BindLambdaForTesting([&](PlatformChannelEndpoint accepted) {
        endpoint_b = std::move(accepted);
        run_loop.Quit();
      });
  EXPECT_TRUE(
      server().Listen(std::move(server_endpoint), std::move(on_connection)));

  // Wait for the server to deliver its side of the connection.
  run_loop.Run();
  VerifyEndToEndConnection(std::move(endpoint_a), std::move(endpoint_b));
}
TEST_F(PlatformChannelServerTest, WaitForConnection) {
  // Tests the static WaitForConnection() helper: a client connects to a named
  // channel, and the helper is expected to pass the server's side of that
  // connection to the supplied callback.
  auto [server_endpoint, name] = CreateRandomChannel();
  base::RunLoop loop;
  auto endpoint_a = NamedPlatformChannel::ConnectToServer(name);
  // Check the client endpoint right away, as the other tests in this file do.
  // Otherwise a failed connection attempt would only be observable as
  // `loop.Run()` below never returning.
  EXPECT_TRUE(endpoint_a.is_valid());
  PlatformChannelEndpoint endpoint_b;
  PlatformChannelServer::WaitForConnection(
      std::move(server_endpoint),
      base::BindLambdaForTesting([&](PlatformChannelEndpoint endpoint) {
        endpoint_b = std::move(endpoint);
        loop.Quit();
      }));
  // Block until the callback above has received the server-side endpoint.
  loop.Run();
  VerifyEndToEndConnection(std::move(endpoint_a), std::move(endpoint_b));
}
TEST_F(PlatformChannelServerTest, NoCallbackAfterListenConnectStop) {
  // Stop() must guarantee that the ConnectionCallback is never invoked, even
  // when we Listen() and a client connects immediately before Stop() is called.
  auto [server_endpoint, name] = CreateRandomChannel();

  bool callback_invoked = false;
  bool callback_destroyed = false;
  base::RunLoop run_loop;

  // This callback should never run; it should instead be destroyed by the
  // Stop() call below. The RunOnDestruction bound into it fires when that
  // happens, which records the destruction and quits the run loop.
  auto doomed_callback = base::BindOnce(
      [](RunOnDestruction, bool* invoked, PlatformChannelEndpoint) {
        *invoked = true;
      },
      RunOnDestruction([&] {
        callback_destroyed = true;
        run_loop.Quit();
      }),
      &callback_invoked);
  EXPECT_TRUE(
      server().Listen(std::move(server_endpoint), std::move(doomed_callback)));

  auto endpoint = NamedPlatformChannel::ConnectToServer(name);
  server().Stop();
  run_loop.Run();

  EXPECT_TRUE(callback_destroyed);
  EXPECT_FALSE(callback_invoked);
  EXPECT_TRUE(endpoint.is_valid());
}
TEST_F(PlatformChannelServerTest, NoCallbackAfterConnectListenStop) {
  // Stop() must guarantee that the ConnectionCallback is never invoked, even
  // when the client has already connected before a Listen() call which is
  // immediately followed by Stop().
  auto [server_endpoint, name] = CreateRandomChannel();

  bool callback_invoked = false;
  bool callback_destroyed = false;
  base::RunLoop run_loop;

  auto endpoint = NamedPlatformChannel::ConnectToServer(name);
  EXPECT_TRUE(endpoint.is_valid());

  // This callback should never run; it should instead be destroyed by the
  // Stop() call below. The RunOnDestruction bound into it fires when that
  // happens, which records the destruction and quits the run loop.
  auto doomed_callback = base::BindOnce(
      [](RunOnDestruction, bool* invoked, PlatformChannelEndpoint) {
        *invoked = true;
      },
      RunOnDestruction([&] {
        callback_destroyed = true;
        run_loop.Quit();
      }),
      &callback_invoked);
  EXPECT_TRUE(
      server().Listen(std::move(server_endpoint), std::move(doomed_callback)));

  server().Stop();
  run_loop.Run();

  EXPECT_TRUE(callback_destroyed);
  EXPECT_FALSE(callback_invoked);
}
} // namespace
} // namespace mojo