Adds Channel-associated interface support on ChannelProxy's thread
This exposes a way for ChannelProxy users to associate interface
factories with a ChannelProxy, allowing requests to be bound either
on the proxy thread or directly on the IPC thread.
BUG=612500
Review-Url: https://codereview.chromium.org/2147493006
Cr-Commit-Position: refs/heads/master@{#405504}
diff --git a/ipc/ipc_channel.h b/ipc/ipc_channel.h
index 1fc9c6c..33f4d19 100644
--- a/ipc/ipc_channel.h
+++ b/ipc/ipc_channel.h
@@ -13,7 +13,9 @@
#include "base/compiler_specific.h"
#include "base/files/scoped_file.h"
+#include "base/memory/ref_counted.h"
#include "base/process/process.h"
+#include "base/single_thread_task_runner.h"
#include "build/build_config.h"
#include "ipc/ipc_channel_handle.h"
#include "ipc/ipc_endpoint.h"
@@ -120,6 +122,11 @@
const std::string& name,
mojo::ScopedInterfaceEndpointHandle handle) = 0;
+ // Sets the TaskRunner on which to support proxied dispatch for associated
+ // interfaces.
+ virtual void SetProxyTaskRunner(
+ scoped_refptr<base::SingleThreadTaskRunner> task_runner) = 0;
+
// Template helper to add an interface factory to this channel.
template <typename Interface>
using AssociatedInterfaceFactory =
diff --git a/ipc/ipc_channel_mojo.cc b/ipc/ipc_channel_mojo.cc
index 70a12f4..98fde28 100644
--- a/ipc/ipc_channel_mojo.cc
+++ b/ipc/ipc_channel_mojo.cc
@@ -498,4 +498,10 @@
message_reader_->GetRemoteInterface(name, std::move(handle));
}
+// Forwards |task_runner| to the bootstrap so that it can be used for proxied
+// dispatch of associated-interface messages. |bootstrap_| must already exist.
+void ChannelMojo::SetProxyTaskRunner(
+    scoped_refptr<base::SingleThreadTaskRunner> task_runner) {
+  DCHECK(bootstrap_);
+  // |task_runner| is a by-value sink parameter: hand the reference over
+  // instead of taking (and then dropping) an additional one.
+  bootstrap_->SetProxyTaskRunner(std::move(task_runner));
+}
+
} // namespace IPC
diff --git a/ipc/ipc_channel_mojo.h b/ipc/ipc_channel_mojo.h
index c267649..7a35ee7 100644
--- a/ipc/ipc_channel_mojo.h
+++ b/ipc/ipc_channel_mojo.h
@@ -111,6 +111,8 @@
void GetGenericRemoteAssociatedInterface(
const std::string& name,
mojo::ScopedInterfaceEndpointHandle handle) override;
+ void SetProxyTaskRunner(
+ scoped_refptr<base::SingleThreadTaskRunner> task_runner) override;
// ChannelMojo needs to kill its MessagePipeReader in delayed manner
// because the channel wants to kill these readers during the
diff --git a/ipc/ipc_channel_mojo_unittest.cc b/ipc/ipc_channel_mojo_unittest.cc
index e846c13..0e08516 100644
--- a/ipc/ipc_channel_mojo_unittest.cc
+++ b/ipc/ipc_channel_mojo_unittest.cc
@@ -11,9 +11,12 @@
#include <utility>
#include "base/base_paths.h"
+#include "base/bind.h"
#include "base/files/file.h"
#include "base/files/scoped_temp_dir.h"
#include "base/location.h"
+#include "base/macros.h"
+#include "base/message_loop/message_loop.h"
#include "base/path_service.h"
#include "base/pickle.h"
#include "base/run_loop.h"
@@ -103,6 +106,7 @@
void Init(mojo::ScopedMessagePipeHandle handle) {
handle_ = std::move(handle);
}
+
void Connect(IPC::Listener* listener) {
channel_ = IPC::ChannelMojo::Create(std::move(handle_),
IPC::Channel::MODE_CLIENT, listener);
@@ -126,34 +130,40 @@
std::unique_ptr<IPC::ChannelMojo> channel_;
};
-class IPCChannelMojoTest : public testing::Test {
+class IPCChannelMojoTestBase : public testing::Test {
public:
- IPCChannelMojoTest() {}
-
- void TearDown() override { base::RunLoop().RunUntilIdle(); }
-
void InitWithMojo(const std::string& test_client_name) {
handle_ = helper_.StartChild(test_client_name);
}
+ bool WaitForClientShutdown() { return helper_.WaitForChildTestShutdown(); }
+
+ protected:
+ mojo::ScopedMessagePipeHandle TakeHandle() { return std::move(handle_); }
+
+ private:
+ mojo::ScopedMessagePipeHandle handle_;
+ mojo::edk::test::MultiprocessTestHelper helper_;
+};
+
+class IPCChannelMojoTest : public IPCChannelMojoTestBase {
+ public:
+ void TearDown() override { base::RunLoop().RunUntilIdle(); }
+
void CreateChannel(IPC::Listener* listener) {
- channel_ = IPC::ChannelMojo::Create(std::move(handle_),
- IPC::Channel::MODE_SERVER, listener);
+ channel_ = IPC::ChannelMojo::Create(
+ TakeHandle(), IPC::Channel::MODE_SERVER, listener);
}
bool ConnectChannel() { return channel_->Connect(); }
void DestroyChannel() { channel_.reset(); }
- bool WaitForClientShutdown() { return helper_.WaitForChildTestShutdown(); }
-
IPC::Sender* sender() { return channel(); }
IPC::Channel* channel() { return channel_.get(); }
private:
base::MessageLoop message_loop_;
- mojo::edk::test::MultiprocessTestHelper helper_;
- mojo::ScopedMessagePipeHandle handle_;
std::unique_ptr<IPC::Channel> channel_;
};
@@ -699,7 +709,176 @@
Close();
}
+// Test helper that bundles a dedicated IO thread with a ChannelProxy whose
+// channel lives on that thread. The proxy is created by CreateProxy() and
+// initialized by RunProxy(), which consumes the ChannelFactory supplied at
+// construction time.
+class ChannelProxyRunner {
+ public:
+  // Marked explicit so that a ChannelFactory pointer is never implicitly
+  // converted into a runner.
+  explicit ChannelProxyRunner(
+      std::unique_ptr<IPC::ChannelFactory> channel_factory)
+      : channel_factory_(std::move(channel_factory)),
+        io_thread_("ChannelProxyRunner IO thread") {}
+
+  // Starts the IO thread and constructs the proxy on top of its task runner.
+  // The proxy is not initialized until RunProxy() is called.
+  void CreateProxy(IPC::Listener* listener) {
+    io_thread_.StartWithOptions(
+        base::Thread::Options(base::MessageLoop::TYPE_IO, 0));
+    proxy_.reset(new IPC::ChannelProxy(listener, io_thread_.task_runner()));
+  }
+
+  // Initializes the proxy with the stored factory. The factory is moved out,
+  // so this may only be called once, and only after CreateProxy().
+  void RunProxy() { proxy_->Init(std::move(channel_factory_), true); }
+
+  // Returns the proxy, or null before CreateProxy() has been called.
+  IPC::ChannelProxy* proxy() { return proxy_.get(); }
+
+ private:
+  std::unique_ptr<IPC::ChannelFactory> channel_factory_;
+
+  // Declared before |proxy_| so that the proxy is destroyed first.
+  base::Thread io_thread_;
+  std::unique_ptr<IPC::ChannelProxy> proxy_;
+
+  DISALLOW_COPY_AND_ASSIGN(ChannelProxyRunner);
+};
+
+// Fixture for tests that drive the server end of the channel through a
+// ChannelProxy (via ChannelProxyRunner) instead of a bare Channel.
+class IPCChannelProxyMojoTest : public IPCChannelMojoTestBase {
+ public:
+  // Starts the named child through the base fixture, then wraps the resulting
+  // pipe handle in a server-side ChannelProxyRunner.
+  void InitWithMojo(const std::string& client_name) {
+    IPCChannelMojoTestBase::InitWithMojo(client_name);
+    auto server_factory = IPC::ChannelMojo::CreateServerFactory(TakeHandle());
+    runner_.reset(new ChannelProxyRunner(std::move(server_factory)));
+  }
+
+  // Thin forwarders to the runner; InitWithMojo() must have been called.
+  void CreateProxy(IPC::Listener* listener) {
+    runner_->CreateProxy(listener);
+  }
+
+  void RunProxy() {
+    runner_->RunProxy();
+  }
+
+  IPC::ChannelProxy* proxy() {
+    return runner_->proxy();
+  }
+
+ private:
+  base::MessageLoop message_loop_;
+  std::unique_ptr<ChannelProxyRunner> runner_;
+};
+
+// Listener that additionally implements the SimpleTestDriver associated
+// interface. Every legacy IPC message is expected to carry the string most
+// recently announced through ExpectString(), which lets the test check that
+// associated-interface calls and legacy messages arrive in the order sent.
+class ListenerWithSimpleProxyAssociatedInterface
+    : public IPC::Listener,
+      public IPC::mojom::SimpleTestDriver {
+ public:
+  // Number of legacy messages that must arrive for the test to pass.
+  static const int kNumMessages;
+
+  ListenerWithSimpleProxyAssociatedInterface() : binding_(this) {}
+
+  ~ListenerWithSimpleProxyAssociatedInterface() override {}
+
+  // IPC::Listener:
+  // Reads the string payload and compares it with the value last set by
+  // ExpectString(); always reports the message as handled.
+  bool OnMessageReceived(const IPC::Message& message) override {
+    base::PickleIterator iter(message);
+    std::string should_be_expected;
+    EXPECT_TRUE(iter.ReadString(&should_be_expected));
+    EXPECT_EQ(should_be_expected, next_expected_string_);
+    num_messages_received_++;
+    return true;
+  }
+
+  void OnChannelError() override {
+    // CHECK rather than DCHECK: this is a test invariant, and a DCHECK would
+    // not be evaluated in builds where DCHECKs are disabled, letting an early
+    // channel error go unnoticed.
+    CHECK(received_quit_);
+  }
+
+  // Registers BindRequest() as the factory for SimpleTestDriver requests on
+  // |proxy|. Unretained is used, so |this| must outlive the registration.
+  void RegisterInterfaceFactory(IPC::ChannelProxy* proxy) {
+    proxy->AddAssociatedInterface(
+        base::Bind(&ListenerWithSimpleProxyAssociatedInterface::BindRequest,
+                   base::Unretained(this)));
+  }
+
+  // True once exactly kNumMessages legacy messages and the quit request have
+  // been received.
+  bool received_all_messages() const {
+    return num_messages_received_ == kNumMessages && received_quit_;
+  }
+
+ private:
+  // IPC::mojom::SimpleTestDriver:
+  void ExpectString(const mojo::String& str) override {
+    next_expected_string_ = str;
+  }
+
+  // Records the quit request, runs |callback| to acknowledge it, and asks the
+  // current message loop to quit once idle.
+  void RequestQuit(const RequestQuitCallback& callback) override {
+    received_quit_ = true;
+    callback.Run();
+    base::MessageLoop::current()->QuitWhenIdle();
+  }
+
+  // Factory target: binds the single incoming request to |binding_|.
+  void BindRequest(IPC::mojom::SimpleTestDriverAssociatedRequest request) {
+    DCHECK(!binding_.is_bound());
+    binding_.Bind(std::move(request));
+  }
+
+  std::string next_expected_string_;
+  int num_messages_received_ = 0;
+  bool received_quit_ = false;
+
+  mojo::AssociatedBinding<IPC::mojom::SimpleTestDriver> binding_;
+};
+
+const int ListenerWithSimpleProxyAssociatedInterface::kNumMessages = 1000;
+
+// Server side of the test: registers an associated interface on a
+// ChannelProxy, lets the child send its interleaved traffic, and then checks
+// that everything arrived.
+TEST_F(IPCChannelProxyMojoTest, ProxyThreadAssociatedInterface) {
+  InitWithMojo("ProxyThreadAssociatedInterfaceClient");
+
+  ListenerWithSimpleProxyAssociatedInterface listener;
+  CreateProxy(&listener);
+  // The factory has to be in place before the proxy is initialized.
+  listener.RegisterInterfaceFactory(proxy());
+  RunProxy();
+
+  {
+    // Runs until the listener's RequestQuit() handler quits the loop.
+    base::RunLoop quit_loop;
+    quit_loop.Run();
+  }
+
+  EXPECT_TRUE(WaitForClientShutdown());
+  EXPECT_TRUE(listener.received_all_messages());
+
+  // Drain whatever is still queued on this thread before tearing down.
+  base::RunLoop().RunUntilIdle();
+}
+
+// Client-process counterpart of the proxy fixture: wraps the pipe handle it is
+// given in a client-side ChannelProxyRunner.
+class ChannelProxyClient {
+ public:
+  // Builds the runner around a client ChannelFactory for |handle|.
+  void Init(mojo::ScopedMessagePipeHandle handle) {
+    auto client_factory =
+        IPC::ChannelMojo::CreateClientFactory(std::move(handle));
+    runner_.reset(new ChannelProxyRunner(std::move(client_factory)));
+  }
+
+  // Thin forwarders to the runner; Init() must have been called.
+  void CreateProxy(IPC::Listener* listener) {
+    runner_->CreateProxy(listener);
+  }
+
+  void RunProxy() {
+    runner_->RunProxy();
+  }
+
+  IPC::ChannelProxy* proxy() {
+    return runner_->proxy();
+  }
+
+ private:
+  base::MessageLoop message_loop_;
+  std::unique_ptr<ChannelProxyRunner> runner_;
+};
+
+// Listener whose only job is to run a closure when the channel connects.
+// Incoming messages are reported as handled and otherwise ignored.
+class ListenerThatWaitsForConnect : public IPC::Listener {
+ public:
+  explicit ListenerThatWaitsForConnect(const base::Closure& connect_handler)
+      : connect_handler_(connect_handler) {}
+
+  // IPC::Listener:
+  bool OnMessageReceived(const IPC::Message& message) override {
+    return true;
+  }
+
+  // The integer argument is unused here.
+  void OnChannelConnected(int32_t) override {
+    connect_handler_.Run();
+  }
+
+ private:
+  // Run every time OnChannelConnected() is invoked.
+  base::Closure connect_handler_;
+};
+
+// Child side of ProxyThreadAssociatedInterface: connects through a
+// ChannelProxy and then floods the server with paired messages.
+DEFINE_IPC_CHANNEL_MOJO_TEST_CLIENT(ProxyThreadAssociatedInterfaceClient,
+                                    ChannelProxyClient) {
+  // Wait for the connection first; the remote interface is only requested
+  // once the listener has seen OnChannelConnected.
+  base::RunLoop connect_loop;
+  ListenerThatWaitsForConnect listener(connect_loop.QuitClosure());
+  CreateProxy(&listener);
+  RunProxy();
+  connect_loop.Run();
+
+  IPC::mojom::SimpleTestDriverAssociatedPtr driver;
+  proxy()->GetRemoteAssociatedInterface(&driver);
+
+  // Alternate between the associated interface and a legacy IPC::Message:
+  // each string is first announced via ExpectString() and then sent as a
+  // plain message carrying the same payload.
+  const int message_count =
+      ListenerWithSimpleProxyAssociatedInterface::kNumMessages;
+  for (int index = 0; index < message_count; ++index) {
+    const std::string payload = base::StringPrintf("Hello! %d", index);
+    driver->ExpectString(payload);
+    SendString(proxy(), payload);
+  }
+
+  // Ask the server to quit and keep running until its reply arrives.
+  driver->RequestQuit(base::MessageLoop::QuitWhenIdleClosure());
+  base::RunLoop().Run();
+}
+
#if defined(OS_POSIX)
+
class ListenerThatExpectsFile : public IPC::Listener {
public:
ListenerThatExpectsFile() : sender_(NULL) {}
diff --git a/ipc/ipc_channel_proxy.cc b/ipc/ipc_channel_proxy.cc
index b342412..5ac51cd 100644
--- a/ipc/ipc_channel_proxy.cc
+++ b/ipc/ipc_channel_proxy.cc
@@ -27,6 +27,17 @@
namespace IPC {
+namespace {
+
+// Trampoline used for proxy-thread interface factories: rather than running
+// |factory| on the calling thread, posts it to |task_runner| with ownership of
+// |handle| transferred into the posted task.
+void BindAssociatedInterfaceOnTaskRunner(
+    scoped_refptr<base::SingleThreadTaskRunner> task_runner,
+    const ChannelProxy::GenericAssociatedInterfaceFactory& factory,
+    mojo::ScopedInterfaceEndpointHandle handle) {
+  const auto bind_task = base::Bind(factory, base::Passed(&handle));
+  task_runner->PostTask(FROM_HERE, bind_task);
+}
+
+}  // namespace
+
//------------------------------------------------------------------------------
ChannelProxy::Context::Context(
@@ -147,6 +158,23 @@
for (size_t i = 0; i < filters_.size(); ++i)
filters_[i]->OnFilterAdded(channel_.get());
+
+ Channel::AssociatedInterfaceSupport* support =
+ channel_->GetAssociatedInterfaceSupport();
+ if (support) {
+ support->SetProxyTaskRunner(listener_task_runner_);
+ for (auto& entry : io_thread_interfaces_)
+ support->AddGenericAssociatedInterface(entry.first, entry.second);
+ for (auto& entry : proxy_thread_interfaces_) {
+ support->AddGenericAssociatedInterface(
+ entry.first, base::Bind(&BindAssociatedInterfaceOnTaskRunner,
+ listener_task_runner_, entry.second));
+ }
+ } else {
+ // Sanity check to ensure nobody's expecting to use associated interfaces on
+ // a Channel that doesn't support them.
+ DCHECK(io_thread_interfaces_.empty() && proxy_thread_interfaces_.empty());
+ }
}
// Called on the IPC::Channel thread
@@ -296,6 +324,15 @@
if (channel_connected_called_)
return;
+ if (channel_) {
+ Channel::AssociatedInterfaceSupport* associated_interface_support =
+ channel_->GetAssociatedInterfaceSupport();
+ if (associated_interface_support) {
+ channel_associated_group_.reset(new mojo::AssociatedGroup(
+ *associated_interface_support->GetAssociatedGroup()));
+ }
+ }
+
channel_connected_called_ = true;
if (listener_)
listener_->OnChannelConnected(peer_pid_);
@@ -341,6 +378,19 @@
return channel_send_thread_safe_;
}
+// Called on the IPC::Channel thread.
+// Asks the channel's associated-interface support to connect |handle| to the
+// remote interface called |name|. Does nothing when there is no channel, in
+// which case |handle| is simply dropped.
+void ChannelProxy::Context::GetRemoteAssociatedInterface(
+    const std::string& name,
+    mojo::ScopedInterfaceEndpointHandle handle) {
+  if (!channel_)
+    return;
+
+  Channel::AssociatedInterfaceSupport* support =
+      channel_->GetAssociatedInterfaceSupport();
+  // Requesting an interface on a channel without support is a caller bug.
+  DCHECK(support);
+  support->GetGenericRemoteAssociatedInterface(name, std::move(handle));
+}
+
//-----------------------------------------------------------------------------
// static
@@ -479,6 +529,34 @@
base::RetainedRef(filter)));
}
+// Records |factory| under |name| in the set of interfaces to be bound directly
+// on the IO thread. Only legal on the proxy's own thread and before Init().
+// An existing entry with the same name is left untouched.
+void ChannelProxy::AddGenericAssociatedInterfaceForIOThread(
+    const std::string& name,
+    const GenericAssociatedInterfaceFactory& factory) {
+  DCHECK(CalledOnValidThread());
+  DCHECK(!did_init_);
+  context_->io_thread_interfaces_.insert(std::make_pair(name, factory));
+}
+
+// Records |factory| under |name| in the set of interfaces to be bound on the
+// ChannelProxy's thread. Only legal on the proxy's own thread and before
+// Init(). An existing entry with the same name is left untouched.
+void ChannelProxy::AddGenericAssociatedInterface(
+    const std::string& name,
+    const GenericAssociatedInterfaceFactory& factory) {
+  DCHECK(CalledOnValidThread());
+  DCHECK(!did_init_);
+  context_->proxy_thread_interfaces_.insert(std::make_pair(name, factory));
+}
+
+// Returns the group stored on the context, or null if none has been created.
+mojo::AssociatedGroup* ChannelProxy::GetAssociatedGroup() {
+  mojo::AssociatedGroup* const group =
+      context_->channel_associated_group_.get();
+  return group;
+}
+
+// Hops to the IPC task runner, where the context performs the actual request.
+// Ownership of |handle| travels with the posted task.
+void ChannelProxy::GetGenericRemoteAssociatedInterface(
+    const std::string& name,
+    mojo::ScopedInterfaceEndpointHandle handle) {
+  const auto request_task =
+      base::Bind(&Context::GetRemoteAssociatedInterface, context_.get(), name,
+                 base::Passed(&handle));
+  context_->ipc_task_runner()->PostTask(FROM_HERE, request_task);
+}
+
void ChannelProxy::ClearIPCTaskRunner() {
DCHECK(CalledOnValidThread());
diff --git a/ipc/ipc_channel_proxy.h b/ipc/ipc_channel_proxy.h
index 0c93233..eeb0468 100644
--- a/ipc/ipc_channel_proxy.h
+++ b/ipc/ipc_channel_proxy.h
@@ -7,9 +7,12 @@
#include <stdint.h>
+#include <map>
#include <memory>
+#include <string>
#include <vector>
+#include "base/callback.h"
#include "base/memory/ref_counted.h"
#include "base/synchronization/lock.h"
#include "base/threading/non_thread_safe.h"
@@ -19,6 +22,9 @@
#include "ipc/ipc_endpoint.h"
#include "ipc/ipc_listener.h"
#include "ipc/ipc_sender.h"
+#include "mojo/public/cpp/bindings/associated_group.h"
+#include "mojo/public/cpp/bindings/associated_interface_request.h"
+#include "mojo/public/cpp/bindings/scoped_interface_endpoint_handle.h"
namespace base {
class SingleThreadTaskRunner;
@@ -139,6 +145,69 @@
void AddFilter(MessageFilter* filter);
void RemoveFilter(MessageFilter* filter);
+ using GenericAssociatedInterfaceFactory =
+ base::Callback<void(mojo::ScopedInterfaceEndpointHandle)>;
+
+ // Adds a generic associated interface factory to bind incoming interface
+ // requests directly on the IO thread. MUST be called before Init().
+ void AddGenericAssociatedInterfaceForIOThread(
+ const std::string& name,
+ const GenericAssociatedInterfaceFactory& factory);
+
+ // Adds a generic associated interface factory to bind incoming interface
+ // requests on the ChannelProxy's thread. MUST be called before Init().
+ void AddGenericAssociatedInterface(
+ const std::string& name,
+ const GenericAssociatedInterfaceFactory& factory);
+
+ template <typename Interface>
+ using AssociatedInterfaceFactory =
+ base::Callback<void(mojo::AssociatedInterfaceRequest<Interface>)>;
+
+  // Typed wrapper around AddGenericAssociatedInterfaceForIOThread(): the
+  // interface name is taken from Interface::Name_ and |factory| is adapted so
+  // that it receives a typed request instead of a raw endpoint handle. MUST
+  // be called before Init().
+  template <typename Interface>
+  void AddAssociatedInterfaceForIOThread(
+      const AssociatedInterfaceFactory<Interface>& factory) {
+    const GenericAssociatedInterfaceFactory generic_factory = base::Bind(
+        &ChannelProxy::BindAssociatedInterfaceRequest<Interface>, factory);
+    AddGenericAssociatedInterfaceForIOThread(Interface::Name_,
+                                             generic_factory);
+  }
+
+  // Typed wrapper around AddGenericAssociatedInterface(): the interface name
+  // is taken from Interface::Name_ and |factory| is adapted so that it
+  // receives a typed request instead of a raw endpoint handle. Requests are
+  // bound on the ChannelProxy's thread. MUST be called before Init().
+  template <typename Interface>
+  void AddAssociatedInterface(
+      const AssociatedInterfaceFactory<Interface>& factory) {
+    const GenericAssociatedInterfaceFactory generic_factory = base::Bind(
+        &ChannelProxy::BindAssociatedInterfaceRequest<Interface>, factory);
+    AddGenericAssociatedInterface(Interface::Name_, generic_factory);
+  }
+
+ // Gets the AssociatedGroup used to create new associated endpoints on this
+ // ChannelProxy. This must only be called after the listener's
+ // OnChannelConnected is called.
+ mojo::AssociatedGroup* GetAssociatedGroup();
+
+ // Requests an associated interface from the remote endpoint.
+ void GetGenericRemoteAssociatedInterface(
+ const std::string& name,
+ mojo::ScopedInterfaceEndpointHandle handle);
+
+  // Typed wrapper around GetGenericRemoteAssociatedInterface(): creates the
+  // request end for |proxy| within this ChannelProxy's associated group and
+  // sends it to the remote endpoint under Interface::Name_. Must only be
+  // called after the listener's OnChannelConnected is called.
+  template <typename Interface>
+  void GetRemoteAssociatedInterface(
+      mojo::AssociatedInterfacePtr<Interface>* proxy) {
+    mojo::AssociatedGroup* const group = GetAssociatedGroup();
+    mojo::AssociatedInterfaceRequest<Interface> typed_request =
+        mojo::GetProxy(proxy, group);
+    GetGenericRemoteAssociatedInterface(Interface::Name_,
+                                        typed_request.PassHandle());
+  }
+
#if defined(ENABLE_IPC_FUZZER)
void set_outgoing_message_filter(OutgoingMessageFilter* filter) {
outgoing_message_filter_ = filter;
@@ -186,6 +255,11 @@
// Indicates if the underlying channel's Send is thread-safe.
bool IsChannelSendThreadSafe() const;
+ // Requests a remote associated interface on the IPC thread.
+ void GetRemoteAssociatedInterface(
+ const std::string& name,
+ mojo::ScopedInterfaceEndpointHandle handle);
+
protected:
friend class base::RefCountedThreadSafe<Context>;
~Context() override;
@@ -277,6 +351,16 @@
// Whether this channel is used as an endpoint for sending and receiving
// brokerable attachment messages to/from the broker process.
bool attachment_broker_endpoint_;
+
+ // Modified only on the listener's thread before Init() is called.
+ std::map<std::string, GenericAssociatedInterfaceFactory>
+ io_thread_interfaces_;
+ std::map<std::string, GenericAssociatedInterfaceFactory>
+ proxy_thread_interfaces_;
+
+ // Valid and constant any time after the ChannelProxy's Listener receives
+ // OnChannelConnected on its own thread.
+ std::unique_ptr<mojo::AssociatedGroup> channel_associated_group_;
};
Context* context() { return context_.get(); }
@@ -293,6 +377,15 @@
private:
friend class IpcSecurityTestUtil;
+  // Adapter between the generic and the typed factory signatures: wraps the
+  // raw endpoint |handle| in a request for Interface and passes it to
+  // |factory|.
+  template <typename Interface>
+  static void BindAssociatedInterfaceRequest(
+      const AssociatedInterfaceFactory<Interface>& factory,
+      mojo::ScopedInterfaceEndpointHandle handle) {
+    mojo::AssociatedInterfaceRequest<Interface> typed_request;
+    typed_request.Bind(std::move(handle));
+    factory.Run(std::move(typed_request));
+  }
+
// Always called once immediately after Init.
virtual void OnChannelInit();
diff --git a/ipc/ipc_mojo_bootstrap.cc b/ipc/ipc_mojo_bootstrap.cc
index 425f7948..fc39d0d9 100644
--- a/ipc/ipc_mojo_bootstrap.cc
+++ b/ipc/ipc_mojo_bootstrap.cc
@@ -121,7 +121,7 @@
if (!is_local) {
DCHECK(ContainsKey(endpoints_, id));
DCHECK(!mojo::IsMasterInterfaceId(id));
- control_message_proxy_.NotifyEndpointClosedBeforeSent(id);
+ NotifyEndpointClosedBeforeSent(id);
return;
}
@@ -132,7 +132,7 @@
MarkClosedAndMaybeRemove(endpoint);
if (!mojo::IsMasterInterfaceId(id))
- control_message_proxy_.NotifyPeerEndpointClosed(id);
+ NotifyPeerEndpointClosed(id);
}
mojo::InterfaceEndpointController* AttachEndpointClient(
@@ -392,6 +392,30 @@
endpoints_.erase(endpoint->id());
}
+  // Sends the "peer endpoint closed" control message for |id|. The message is
+  // only sent from |task_runner_|'s thread; calls made on any other thread
+  // re-post themselves there. Nothing is sent if the connector is no longer
+  // valid.
+  void NotifyPeerEndpointClosed(mojo::InterfaceId id) {
+    if (!task_runner_->BelongsToCurrentThread()) {
+      task_runner_->PostTask(
+          FROM_HERE,
+          base::Bind(&ChannelAssociatedGroupController::NotifyPeerEndpointClosed,
+                     this, id));
+      return;
+    }
+
+    if (connector_.is_valid())
+      control_message_proxy_.NotifyPeerEndpointClosed(id);
+  }
+
+  // Sends the "endpoint closed before sent" control message for |id|. The
+  // message is only sent from |task_runner_|'s thread; calls made on any other
+  // thread re-post themselves there. Nothing is sent if the connector is no
+  // longer valid.
+  void NotifyEndpointClosedBeforeSent(mojo::InterfaceId id) {
+    if (!task_runner_->BelongsToCurrentThread()) {
+      task_runner_->PostTask(
+          FROM_HERE,
+          base::Bind(
+              &ChannelAssociatedGroupController::NotifyEndpointClosedBeforeSent,
+              this, id));
+      return;
+    }
+
+    if (connector_.is_valid())
+      control_message_proxy_.NotifyEndpointClosedBeforeSent(id);
+  }
+
Endpoint* FindOrInsertEndpoint(mojo::InterfaceId id, bool* inserted) {
lock_.AssertAcquired();
DCHECK(!inserted || !*inserted);
@@ -411,26 +435,15 @@
bool Accept(mojo::Message* message) override {
DCHECK(thread_checker_.CalledOnValidThread());
- if (mojo::PipeControlMessageHandler::IsPipeControlMessage(message)) {
- if (!control_message_handler_.Accept(message))
- RaiseError();
- return true;
- }
+ if (mojo::PipeControlMessageHandler::IsPipeControlMessage(message))
+ return control_message_handler_.Accept(message);
mojo::InterfaceId id = message->interface_id();
DCHECK(mojo::IsValidInterfaceId(id));
base::AutoLock locker(lock_);
- bool inserted = false;
- Endpoint* endpoint = FindOrInsertEndpoint(id, &inserted);
- if (inserted) {
- MarkClosedAndMaybeRemove(endpoint);
- if (!mojo::IsMasterInterfaceId(id))
- control_message_proxy_.NotifyPeerEndpointClosed(id);
- return true;
- }
-
- if (endpoint->closed())
+ Endpoint* endpoint = GetEndpointForDispatch(id);
+ if (!endpoint)
return true;
mojo::InterfaceEndpointClient* client = endpoint->client();
@@ -442,7 +455,6 @@
// If the client is not yet bound, it must be bound by the time this task
// runs or else it's programmer error.
DCHECK(proxy_task_runner_);
- CHECK(false);
std::unique_ptr<mojo::Message> passed_message(new mojo::Message);
message->MoveTo(passed_message.get());
proxy_task_runner_->PostTask(
@@ -456,23 +468,56 @@
// If it's happening, it's a bug.
DCHECK(!message->has_flag(mojo::Message::kFlagIsSync));
- bool result = false;
- {
- base::AutoUnlock unlocker(lock_);
- result = client->HandleIncomingMessage(message);
- }
-
- if (!result)
- RaiseError();
-
- return true;
+ base::AutoUnlock unlocker(lock_);
+ return client->HandleIncomingMessage(message);
}
void AcceptOnProxyThread(std::unique_ptr<mojo::Message> message) {
DCHECK(proxy_task_runner_->BelongsToCurrentThread());
- // TODO(rockot): Implement this.
- NOTREACHED();
+ mojo::InterfaceId id = message->interface_id();
+ DCHECK(mojo::IsValidInterfaceId(id) && !mojo::IsMasterInterfaceId(id));
+
+ base::AutoLock locker(lock_);
+ Endpoint* endpoint = GetEndpointForDispatch(id);
+ if (!endpoint)
+ return;
+
+ mojo::InterfaceEndpointClient* client = endpoint->client();
+ if (!client)
+ return;
+
+ DCHECK(endpoint->task_runner()->BelongsToCurrentThread());
+
+ // TODO(rockot): Implement sync dispatch. For now, sync messages are
+ // unsupported here.
+ DCHECK(!message->has_flag(mojo::Message::kFlagIsSync));
+
+ bool result = false;
+ {
+ base::AutoUnlock unlocker(lock_);
+ result = client->HandleIncomingMessage(message.get());
+ }
+
+ if (!result)
+ RaiseError();
+ }
+
+  // Looks up the endpoint that a message for |id| should be delivered to.
+  // Returns null when the message must be dropped: either the endpoint was
+  // unknown (it is then marked closed and, unless |id| is the master
+  // interface, the peer is told so) or it is already closed. |lock_| must be
+  // held by the caller.
+  Endpoint* GetEndpointForDispatch(mojo::InterfaceId id) {
+    lock_.AssertAcquired();
+
+    bool newly_inserted = false;
+    Endpoint* endpoint = FindOrInsertEndpoint(id, &newly_inserted);
+    if (!newly_inserted)
+      return endpoint->closed() ? nullptr : endpoint;
+
+    // Nobody was expecting this endpoint; close it right away.
+    MarkClosedAndMaybeRemove(endpoint);
+    if (!mojo::IsMasterInterfaceId(id))
+      NotifyPeerEndpointClosed(id);
+    return nullptr;
+  }
// mojo::PipeControlMessageHandlerDelegate:
@@ -561,6 +606,11 @@
return controller_->associated_group();
}
+  // Returns the controller, which must already have been created.
+  ChannelAssociatedGroupController* controller() {
+    ChannelAssociatedGroupController* const raw_controller = controller_.get();
+    DCHECK(raw_controller);
+    return raw_controller;
+  }
+
mojom::Bootstrap* operator->() {
DCHECK(proxy_);
return proxy_.get();
@@ -596,6 +646,11 @@
return controller_->associated_group();
}
+  // Returns the controller, which must already have been created.
+  ChannelAssociatedGroupController* controller() {
+    ChannelAssociatedGroupController* const raw_controller = controller_.get();
+    DCHECK(raw_controller);
+    return raw_controller;
+  }
+
void Bind(mojo::ScopedMessagePipeHandle handle) {
DCHECK(!controller_);
controller_ =
@@ -626,10 +681,16 @@
private:
// MojoBootstrap implementation.
void Connect() override;
+
mojo::AssociatedGroup* GetAssociatedGroup() override {
return bootstrap_.associated_group();
}
+  // Passes |task_runner| on to the controller obtained from |bootstrap_|.
+  void SetProxyTaskRunner(
+      scoped_refptr<base::SingleThreadTaskRunner> task_runner) override {
+    ChannelAssociatedGroupController* const group_controller =
+        bootstrap_.controller();
+    group_controller->SetProxyTaskRunner(task_runner);
+  }
+
void OnInitDone(int32_t peer_pid);
BootstrapMasterProxy bootstrap_;
@@ -688,10 +749,16 @@
private:
// MojoBootstrap implementation.
void Connect() override;
+
mojo::AssociatedGroup* GetAssociatedGroup() override {
return binding_.associated_group();
}
+  // Passes |task_runner| on to the controller obtained from |binding_|.
+  void SetProxyTaskRunner(
+      scoped_refptr<base::SingleThreadTaskRunner> task_runner) override {
+    ChannelAssociatedGroupController* const group_controller =
+        binding_.controller();
+    group_controller->SetProxyTaskRunner(task_runner);
+  }
+
// mojom::Bootstrap implementation.
void Init(mojom::ChannelAssociatedRequest receive_channel,
mojom::ChannelAssociatedPtrInfo send_channel,
diff --git a/ipc/ipc_mojo_bootstrap.h b/ipc/ipc_mojo_bootstrap.h
index 1d632d3..b9df4083 100644
--- a/ipc/ipc_mojo_bootstrap.h
+++ b/ipc/ipc_mojo_bootstrap.h
@@ -10,7 +10,9 @@
#include <memory>
#include "base/macros.h"
+#include "base/memory/ref_counted.h"
#include "base/process/process_handle.h"
+#include "base/single_thread_task_runner.h"
#include "build/build_config.h"
#include "ipc/ipc.mojom.h"
#include "ipc/ipc_channel.h"
@@ -54,6 +56,9 @@
virtual mojo::AssociatedGroup* GetAssociatedGroup() = 0;
+ virtual void SetProxyTaskRunner(
+ scoped_refptr<base::SingleThreadTaskRunner> task_runner) = 0;
+
// GetSelfPID returns our PID.
base::ProcessId GetSelfPID() const;