[go: nahoru, domu]

Adds Channel-associated interface support on ChannelProxy's thread

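This exposes a way for ChannelProxy users to register associated interface
factories with a ChannelProxy, allowing incoming requests to be bound either
on the proxy thread or directly on the IPC thread. To support proxy-thread
binding, the Mojo bootstrap's ChannelAssociatedGroupController now actually
dispatches messages on the proxy thread (AcceptOnProxyThread) instead of
hitting NOTREACHED().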

BUG=612500

Review-Url: https://codereview.chromium.org/2147493006
Cr-Commit-Position: refs/heads/master@{#405504}
diff --git a/ipc/ipc_channel.h b/ipc/ipc_channel.h
index 1fc9c6c..33f4d19 100644
--- a/ipc/ipc_channel.h
+++ b/ipc/ipc_channel.h
@@ -13,7 +13,9 @@
 
 #include "base/compiler_specific.h"
 #include "base/files/scoped_file.h"
+#include "base/memory/ref_counted.h"
 #include "base/process/process.h"
+#include "base/single_thread_task_runner.h"
 #include "build/build_config.h"
 #include "ipc/ipc_channel_handle.h"
 #include "ipc/ipc_endpoint.h"
@@ -120,6 +122,11 @@
         const std::string& name,
         mojo::ScopedInterfaceEndpointHandle handle) = 0;
 
+    // Sets the TaskRunner (e.g. a ChannelProxy's listener thread) to which
+    // incoming associated-interface messages may be posted for dispatch.
+    virtual void SetProxyTaskRunner(
+        scoped_refptr<base::SingleThreadTaskRunner> task_runner) = 0;
+
     // Template helper to add an interface factory to this channel.
     template <typename Interface>
     using AssociatedInterfaceFactory =
diff --git a/ipc/ipc_channel_mojo.cc b/ipc/ipc_channel_mojo.cc
index 70a12f4..98fde28 100644
--- a/ipc/ipc_channel_mojo.cc
+++ b/ipc/ipc_channel_mojo.cc
@@ -498,4 +498,10 @@
   message_reader_->GetRemoteInterface(name, std::move(handle));
 }
 
+void ChannelMojo::SetProxyTaskRunner(
+    scoped_refptr<base::SingleThreadTaskRunner> task_runner) {
+  DCHECK(bootstrap_);
+  bootstrap_->SetProxyTaskRunner(std::move(task_runner));
+}
+
 }  // namespace IPC
diff --git a/ipc/ipc_channel_mojo.h b/ipc/ipc_channel_mojo.h
index c267649..7a35ee7 100644
--- a/ipc/ipc_channel_mojo.h
+++ b/ipc/ipc_channel_mojo.h
@@ -111,6 +111,8 @@
   void GetGenericRemoteAssociatedInterface(
       const std::string& name,
       mojo::ScopedInterfaceEndpointHandle handle) override;
+  void SetProxyTaskRunner(
+      scoped_refptr<base::SingleThreadTaskRunner> task_runner) override;
 
   // ChannelMojo needs to kill its MessagePipeReader in delayed manner
   // because the channel wants to kill these readers during the
diff --git a/ipc/ipc_channel_mojo_unittest.cc b/ipc/ipc_channel_mojo_unittest.cc
index e846c13..0e08516 100644
--- a/ipc/ipc_channel_mojo_unittest.cc
+++ b/ipc/ipc_channel_mojo_unittest.cc
@@ -11,9 +11,12 @@
 #include <utility>
 
 #include "base/base_paths.h"
+#include "base/bind.h"
 #include "base/files/file.h"
 #include "base/files/scoped_temp_dir.h"
 #include "base/location.h"
+#include "base/macros.h"
+#include "base/message_loop/message_loop.h"
 #include "base/path_service.h"
 #include "base/pickle.h"
 #include "base/run_loop.h"
@@ -103,6 +106,7 @@
   void Init(mojo::ScopedMessagePipeHandle handle) {
     handle_ = std::move(handle);
   }
+
   void Connect(IPC::Listener* listener) {
     channel_ = IPC::ChannelMojo::Create(std::move(handle_),
                                         IPC::Channel::MODE_CLIENT, listener);
@@ -126,34 +130,40 @@
   std::unique_ptr<IPC::ChannelMojo> channel_;
 };
 
-class IPCChannelMojoTest : public testing::Test {
+class IPCChannelMojoTestBase : public testing::Test {
  public:
-  IPCChannelMojoTest() {}
-
-  void TearDown() override { base::RunLoop().RunUntilIdle(); }
-
   void InitWithMojo(const std::string& test_client_name) {
     handle_ = helper_.StartChild(test_client_name);
   }
 
+  bool WaitForClientShutdown() { return helper_.WaitForChildTestShutdown(); }
+
+ protected:
+  mojo::ScopedMessagePipeHandle TakeHandle() { return std::move(handle_); }
+
+ private:
+  mojo::ScopedMessagePipeHandle handle_;
+  mojo::edk::test::MultiprocessTestHelper helper_;
+};
+
+class IPCChannelMojoTest : public IPCChannelMojoTestBase {
+ public:
+  void TearDown() override { base::RunLoop().RunUntilIdle(); }
+
   void CreateChannel(IPC::Listener* listener) {
-    channel_ = IPC::ChannelMojo::Create(std::move(handle_),
-                                        IPC::Channel::MODE_SERVER, listener);
+    channel_ = IPC::ChannelMojo::Create(
+        TakeHandle(), IPC::Channel::MODE_SERVER, listener);
   }
 
   bool ConnectChannel() { return channel_->Connect(); }
 
   void DestroyChannel() { channel_.reset(); }
 
-  bool WaitForClientShutdown() { return helper_.WaitForChildTestShutdown(); }
-
   IPC::Sender* sender() { return channel(); }
   IPC::Channel* channel() { return channel_.get(); }
 
  private:
   base::MessageLoop message_loop_;
-  mojo::edk::test::MultiprocessTestHelper helper_;
-  mojo::ScopedMessagePipeHandle handle_;
   std::unique_ptr<IPC::Channel> channel_;
 };
 
@@ -699,7 +709,176 @@
   Close();
 }
 
+class ChannelProxyRunner {
+ public:
+  explicit ChannelProxyRunner(
+      std::unique_ptr<IPC::ChannelFactory> channel_factory)
+      : channel_factory_(std::move(channel_factory)),
+        io_thread_("ChannelProxyRunner IO thread") {}
+  // Starts the IO thread and constructs (but does not Init()) the proxy.
+  void CreateProxy(IPC::Listener* listener) {
+    io_thread_.StartWithOptions(
+        base::Thread::Options(base::MessageLoop::TYPE_IO, 0));
+    proxy_.reset(new IPC::ChannelProxy(listener, io_thread_.task_runner()));
+  }
+  void RunProxy() { proxy_->Init(std::move(channel_factory_), true); }
+
+  IPC::ChannelProxy* proxy() { return proxy_.get(); }
+
+ private:
+  std::unique_ptr<IPC::ChannelFactory> channel_factory_;
+  // Declared before |proxy_| so that |proxy_| is destroyed first.
+  base::Thread io_thread_;
+  std::unique_ptr<IPC::ChannelProxy> proxy_;
+
+  DISALLOW_COPY_AND_ASSIGN(ChannelProxyRunner);
+};
+
+class IPCChannelProxyMojoTest : public IPCChannelMojoTestBase {
+ public:
+  void InitWithMojo(const std::string& client_name) {
+    IPCChannelMojoTestBase::InitWithMojo(client_name);
+    runner_.reset(new ChannelProxyRunner(
+        IPC::ChannelMojo::CreateServerFactory(TakeHandle())));
+  }
+  void CreateProxy(IPC::Listener* listener) { runner_->CreateProxy(listener); }
+  void RunProxy() { runner_->RunProxy(); }
+  // Requires InitWithMojo(); returns null until CreateProxy() is called.
+  IPC::ChannelProxy* proxy() { return runner_->proxy(); }
+
+ private:
+  base::MessageLoop message_loop_;  // Declared first: outlives |runner_|.
+  std::unique_ptr<ChannelProxyRunner> runner_;
+};
+
+class ListenerWithSimpleProxyAssociatedInterface
+    : public IPC::Listener,
+      public IPC::mojom::SimpleTestDriver {
+ public:
+  static const int kNumMessages;  // Messages the client sends on each path.
+
+  ListenerWithSimpleProxyAssociatedInterface() : binding_(this) {}
+
+  ~ListenerWithSimpleProxyAssociatedInterface() override {}
+
+  bool OnMessageReceived(const IPC::Message& message) override {
+    base::PickleIterator iter(message);
+    std::string received_string;
+    EXPECT_TRUE(iter.ReadString(&received_string));
+    EXPECT_EQ(next_expected_string_, received_string);
+    num_messages_received_++;
+    return true;
+  }
+
+  void OnChannelError() override {
+    EXPECT_TRUE(received_quit_);  // Only expected after RequestQuit().
+  }
+
+  void RegisterInterfaceFactory(IPC::ChannelProxy* proxy) {
+    proxy->AddAssociatedInterface(
+        base::Bind(&ListenerWithSimpleProxyAssociatedInterface::BindRequest,
+                   base::Unretained(this)));
+  }
+
+  bool received_all_messages() const {
+    return num_messages_received_ == kNumMessages && received_quit_;
+  }
+
+ private:
+  // IPC::mojom::SimpleTestDriver:
+  void ExpectString(const mojo::String& str) override {
+    next_expected_string_ = str;
+  }
+
+  void RequestQuit(const RequestQuitCallback& callback) override {
+    received_quit_ = true;
+    callback.Run();
+    base::MessageLoop::current()->QuitWhenIdle();
+  }
+
+  void BindRequest(IPC::mojom::SimpleTestDriverAssociatedRequest request) {
+    DCHECK(!binding_.is_bound());
+    binding_.Bind(std::move(request));
+  }
+
+  std::string next_expected_string_;
+  int num_messages_received_ = 0;
+  bool received_quit_ = false;
+
+  mojo::AssociatedBinding<IPC::mojom::SimpleTestDriver> binding_;
+};
+
+const int ListenerWithSimpleProxyAssociatedInterface::kNumMessages = 1000;
+
+TEST_F(IPCChannelProxyMojoTest, ProxyThreadAssociatedInterface) {
+  InitWithMojo("ProxyThreadAssociatedInterfaceClient");
+
+  ListenerWithSimpleProxyAssociatedInterface listener;
+  CreateProxy(&listener);
+  listener.RegisterInterfaceFactory(proxy());  // Must precede RunProxy().
+  RunProxy();
+
+  base::RunLoop().Run();  // Quit by the listener's RequestQuit() handler.
+
+  EXPECT_TRUE(WaitForClientShutdown());
+  EXPECT_TRUE(listener.received_all_messages());
+
+  base::RunLoop().RunUntilIdle();
+}
+
+class ChannelProxyClient {
+ public:
+  void Init(mojo::ScopedMessagePipeHandle handle) {
+    runner_.reset(new ChannelProxyRunner(
+        IPC::ChannelMojo::CreateClientFactory(std::move(handle))));
+  }
+  void CreateProxy(IPC::Listener* listener) { runner_->CreateProxy(listener); }
+  void RunProxy() { runner_->RunProxy(); }
+  // Requires Init(); returns null until CreateProxy() is called.
+  IPC::ChannelProxy* proxy() { return runner_->proxy(); }
+
+ private:
+  base::MessageLoop message_loop_;  // Declared first: outlives |runner_|.
+  std::unique_ptr<ChannelProxyRunner> runner_;
+};
+
+class ListenerThatWaitsForConnect : public IPC::Listener {
+ public:
+  explicit ListenerThatWaitsForConnect(const base::Closure& connect_handler)
+      : connect_handler_(connect_handler) {}
+
+  // IPC::Listener: accepts every message; runs |connect_handler_| on connect.
+  bool OnMessageReceived(const IPC::Message& message) override { return true; }
+  void OnChannelConnected(int32_t) override { connect_handler_.Run(); }
+
+ private:
+  base::Closure connect_handler_;
+};
+
+DEFINE_IPC_CHANNEL_MOJO_TEST_CLIENT(ProxyThreadAssociatedInterfaceClient,
+                                    ChannelProxyClient) {
+  base::RunLoop connect_loop;
+  ListenerThatWaitsForConnect listener(connect_loop.QuitClosure());
+  CreateProxy(&listener);
+  RunProxy();
+  connect_loop.Run();  // Must be connected before requesting interfaces.
+
+  // Send a bunch of interleaved messages, alternating between the associated
+  // interface and a legacy IPC::Message.
+  IPC::mojom::SimpleTestDriverAssociatedPtr driver;
+  proxy()->GetRemoteAssociatedInterface(&driver);
+  for (int i = 0; i < ListenerWithSimpleProxyAssociatedInterface::kNumMessages;
+       ++i) {
+    std::string str = base::StringPrintf("Hello! %d", i);
+    driver->ExpectString(str);
+    SendString(proxy(), str);
+  }
+  driver->RequestQuit(base::MessageLoop::QuitWhenIdleClosure());
+  base::RunLoop().Run();  // Quits when the RequestQuit() reply arrives.
+}
+
 #if defined(OS_POSIX)
+
 class ListenerThatExpectsFile : public IPC::Listener {
  public:
   ListenerThatExpectsFile() : sender_(NULL) {}
diff --git a/ipc/ipc_channel_proxy.cc b/ipc/ipc_channel_proxy.cc
index b342412..5ac51cd 100644
--- a/ipc/ipc_channel_proxy.cc
+++ b/ipc/ipc_channel_proxy.cc
@@ -27,6 +27,17 @@
 
 namespace IPC {
 
+namespace {
+// Posts a task to |task_runner| that runs |factory| with |handle|.
+void BindAssociatedInterfaceOnTaskRunner(
+    scoped_refptr<base::SingleThreadTaskRunner> task_runner,
+    const ChannelProxy::GenericAssociatedInterfaceFactory& factory,
+    mojo::ScopedInterfaceEndpointHandle handle) {
+  task_runner->PostTask(FROM_HERE, base::Bind(factory, base::Passed(&handle)));
+}
+
+}  // namespace
+
 //------------------------------------------------------------------------------
 
 ChannelProxy::Context::Context(
@@ -147,6 +158,23 @@
 
   for (size_t i = 0; i < filters_.size(); ++i)
     filters_[i]->OnFilterAdded(channel_.get());
+
+  Channel::AssociatedInterfaceSupport* support =
+      channel_->GetAssociatedInterfaceSupport();
+  if (support) {
+    support->SetProxyTaskRunner(listener_task_runner_);
+    for (auto& entry : io_thread_interfaces_)
+      support->AddGenericAssociatedInterface(entry.first, entry.second);
+    for (auto& entry : proxy_thread_interfaces_) {
+      support->AddGenericAssociatedInterface(
+          entry.first, base::Bind(&BindAssociatedInterfaceOnTaskRunner,
+                                  listener_task_runner_, entry.second));
+    }
+  } else {
+    // Sanity check to ensure nobody's expecting to use associated interfaces on
+    // a Channel that doesn't support them.
+    DCHECK(io_thread_interfaces_.empty() && proxy_thread_interfaces_.empty());
+  }
 }
 
 // Called on the IPC::Channel thread
@@ -296,6 +324,15 @@
   if (channel_connected_called_)
     return;
 
+  if (channel_) {
+    Channel::AssociatedInterfaceSupport* associated_interface_support =
+        channel_->GetAssociatedInterfaceSupport();
+    if (associated_interface_support) {
+      channel_associated_group_.reset(new mojo::AssociatedGroup(
+          *associated_interface_support->GetAssociatedGroup()));
+    }
+  }
+
   channel_connected_called_ = true;
   if (listener_)
     listener_->OnChannelConnected(peer_pid_);
@@ -341,6 +378,19 @@
   return channel_send_thread_safe_;
 }
 
+// Called on the IPC::Channel thread. Drops |handle| if there is no |channel_|.
+void ChannelProxy::Context::GetRemoteAssociatedInterface(
+    const std::string& name,
+    mojo::ScopedInterfaceEndpointHandle handle) {
+  if (!channel_)
+    return;
+  Channel::AssociatedInterfaceSupport* associated_interface_support =
+      channel_->GetAssociatedInterfaceSupport();
+  DCHECK(associated_interface_support);
+  associated_interface_support->GetGenericRemoteAssociatedInterface(
+      name, std::move(handle));
+}
+
 //-----------------------------------------------------------------------------
 
 // static
@@ -479,6 +529,34 @@
                             base::RetainedRef(filter)));
 }
 
+void ChannelProxy::AddGenericAssociatedInterfaceForIOThread(
+    const std::string& name,
+    const GenericAssociatedInterfaceFactory& factory) {
+  DCHECK(CalledOnValidThread());
+  DCHECK(!did_init_);
+  context_->io_thread_interfaces_.emplace(name, factory);
+}
+
+void ChannelProxy::AddGenericAssociatedInterface(
+    const std::string& name,
+    const GenericAssociatedInterfaceFactory& factory) {
+  DCHECK(CalledOnValidThread());
+  DCHECK(!did_init_);
+  context_->proxy_thread_interfaces_.emplace(name, factory);
+}
+
+mojo::AssociatedGroup* ChannelProxy::GetAssociatedGroup() {
+  return context_->channel_associated_group_.get();
+}
+
+void ChannelProxy::GetGenericRemoteAssociatedInterface(
+    const std::string& name,
+    mojo::ScopedInterfaceEndpointHandle handle) {
+  context_->ipc_task_runner()->PostTask(
+      FROM_HERE, base::Bind(&Context::GetRemoteAssociatedInterface,
+                            context_.get(), name, base::Passed(&handle)));
+}
+
 void ChannelProxy::ClearIPCTaskRunner() {
   DCHECK(CalledOnValidThread());
 
diff --git a/ipc/ipc_channel_proxy.h b/ipc/ipc_channel_proxy.h
index 0c93233..eeb0468 100644
--- a/ipc/ipc_channel_proxy.h
+++ b/ipc/ipc_channel_proxy.h
@@ -7,9 +7,12 @@
 
 #include <stdint.h>
 
+#include <map>
 #include <memory>
+#include <string>
 #include <vector>
 
+#include "base/callback.h"
 #include "base/memory/ref_counted.h"
 #include "base/synchronization/lock.h"
 #include "base/threading/non_thread_safe.h"
@@ -19,6 +22,9 @@
 #include "ipc/ipc_endpoint.h"
 #include "ipc/ipc_listener.h"
 #include "ipc/ipc_sender.h"
+#include "mojo/public/cpp/bindings/associated_group.h"
+#include "mojo/public/cpp/bindings/associated_interface_request.h"
+#include "mojo/public/cpp/bindings/scoped_interface_endpoint_handle.h"
 
 namespace base {
 class SingleThreadTaskRunner;
@@ -139,6 +145,69 @@
   void AddFilter(MessageFilter* filter);
   void RemoveFilter(MessageFilter* filter);
 
+  using GenericAssociatedInterfaceFactory =
+      base::Callback<void(mojo::ScopedInterfaceEndpointHandle)>;
+
+  // Adds a generic associated interface factory to bind incoming interface
+  // requests directly on the IO thread. MUST be called before Init().
+  void AddGenericAssociatedInterfaceForIOThread(
+      const std::string& name,
+      const GenericAssociatedInterfaceFactory& factory);
+
+  // Adds a generic associated interface factory to bind incoming interface
+  // requests on the ChannelProxy's thread. MUST be called before Init().
+  void AddGenericAssociatedInterface(
+      const std::string& name,
+      const GenericAssociatedInterfaceFactory& factory);
+
+  template <typename Interface>
+  using AssociatedInterfaceFactory =
+      base::Callback<void(mojo::AssociatedInterfaceRequest<Interface>)>;
+
+  // Helper to bind an IO-thread associated interface factory, inferring the
+  // interface name from the callback argument's type. MUST be called before
+  // Init().
+  template <typename Interface>
+  void AddAssociatedInterfaceForIOThread(
+      const AssociatedInterfaceFactory<Interface>& factory) {
+    AddGenericAssociatedInterfaceForIOThread(
+        Interface::Name_,
+        base::Bind(&ChannelProxy::BindAssociatedInterfaceRequest<Interface>,
+                   factory));
+  }
+
+  // Helper to bind a ChannelProxy-thread associated interface factory,
+  // inferring the interface name from the callback argument's type. MUST be
+  // called before Init().
+  template <typename Interface>
+  void AddAssociatedInterface(
+      const AssociatedInterfaceFactory<Interface>& factory) {
+    AddGenericAssociatedInterface(
+        Interface::Name_,
+        base::Bind(&ChannelProxy::BindAssociatedInterfaceRequest<Interface>,
+                   factory));
+  }
+
+  // Gets the AssociatedGroup used to create new associated endpoints on this
+  // ChannelProxy. This must only be called after the listener's
+  // OnChannelConnected is called; before that it may return null.
+  mojo::AssociatedGroup* GetAssociatedGroup();
+
+  // Asynchronously requests an associated interface from the remote endpoint.
+  void GetGenericRemoteAssociatedInterface(
+      const std::string& name,
+      mojo::ScopedInterfaceEndpointHandle handle);
+
+  // Template helper to request associated interfaces from the remote endpoint.
+  // Must only be called after the listener's OnChannelConnected is called.
+  template <typename Interface>
+  void GetRemoteAssociatedInterface(
+      mojo::AssociatedInterfacePtr<Interface>* proxy) {
+    mojo::AssociatedInterfaceRequest<Interface> request =
+        mojo::GetProxy(proxy, GetAssociatedGroup());
+    GetGenericRemoteAssociatedInterface(Interface::Name_, request.PassHandle());
+  }
+
 #if defined(ENABLE_IPC_FUZZER)
   void set_outgoing_message_filter(OutgoingMessageFilter* filter) {
     outgoing_message_filter_ = filter;
@@ -186,6 +255,11 @@
     // Indicates if the underlying channel's Send is thread-safe.
     bool IsChannelSendThreadSafe() const;
 
+    // Requests a remote associated interface on the IPC thread.
+    void GetRemoteAssociatedInterface(
+        const std::string& name,
+        mojo::ScopedInterfaceEndpointHandle handle);
+
    protected:
     friend class base::RefCountedThreadSafe<Context>;
     ~Context() override;
@@ -277,6 +351,16 @@
     // Whether this channel is used as an endpoint for sending and receiving
     // brokerable attachment messages to/from the broker process.
     bool attachment_broker_endpoint_;
+
+    // Modified only on the listener's thread before Init() is called.
+    std::map<std::string, GenericAssociatedInterfaceFactory>
+        io_thread_interfaces_;
+    std::map<std::string, GenericAssociatedInterfaceFactory>
+        proxy_thread_interfaces_;
+
+    // Valid and constant any time after the ChannelProxy's Listener receives
+    // OnChannelConnected on its own thread.
+    std::unique_ptr<mojo::AssociatedGroup> channel_associated_group_;
   };
 
   Context* context() { return context_.get(); }
@@ -293,6 +377,15 @@
  private:
   friend class IpcSecurityTestUtil;
 
+  template <typename Interface>
+  static void BindAssociatedInterfaceRequest(
+      const AssociatedInterfaceFactory<Interface>& factory,
+      mojo::ScopedInterfaceEndpointHandle handle) {
+    mojo::AssociatedInterfaceRequest<Interface> request;
+    request.Bind(std::move(handle));  // Adopt the untyped endpoint handle.
+    factory.Run(std::move(request));
+  }
+
   // Always called once immediately after Init.
   virtual void OnChannelInit();
 
diff --git a/ipc/ipc_mojo_bootstrap.cc b/ipc/ipc_mojo_bootstrap.cc
index 425f7948..fc39d0d9 100644
--- a/ipc/ipc_mojo_bootstrap.cc
+++ b/ipc/ipc_mojo_bootstrap.cc
@@ -121,7 +121,7 @@
     if (!is_local) {
       DCHECK(ContainsKey(endpoints_, id));
       DCHECK(!mojo::IsMasterInterfaceId(id));
-      control_message_proxy_.NotifyEndpointClosedBeforeSent(id);
+      NotifyEndpointClosedBeforeSent(id);
       return;
     }
 
@@ -132,7 +132,7 @@
     MarkClosedAndMaybeRemove(endpoint);
 
     if (!mojo::IsMasterInterfaceId(id))
-      control_message_proxy_.NotifyPeerEndpointClosed(id);
+      NotifyPeerEndpointClosed(id);
   }
 
   mojo::InterfaceEndpointController* AttachEndpointClient(
@@ -392,6 +392,30 @@
       endpoints_.erase(endpoint->id());
   }
 
+  void NotifyPeerEndpointClosed(mojo::InterfaceId id) {
+    if (task_runner_->BelongsToCurrentThread()) {
+      if (connector_.is_valid())  // No-op once |connector_| is invalid.
+        control_message_proxy_.NotifyPeerEndpointClosed(id);
+    } else {
+      task_runner_->PostTask(  // Re-enter on |task_runner_|'s thread.
+          FROM_HERE,
+          base::Bind(&ChannelAssociatedGroupController
+              ::NotifyPeerEndpointClosed, this, id));
+    }
+  }
+
+  void NotifyEndpointClosedBeforeSent(mojo::InterfaceId id) {
+    if (task_runner_->BelongsToCurrentThread()) {
+      if (connector_.is_valid())  // No-op once |connector_| is invalid.
+        control_message_proxy_.NotifyEndpointClosedBeforeSent(id);
+    } else {
+      task_runner_->PostTask(  // Re-enter on |task_runner_|'s thread.
+          FROM_HERE,
+          base::Bind(&ChannelAssociatedGroupController
+              ::NotifyEndpointClosedBeforeSent, this, id));
+    }
+  }
+
   Endpoint* FindOrInsertEndpoint(mojo::InterfaceId id, bool* inserted) {
     lock_.AssertAcquired();
     DCHECK(!inserted || !*inserted);
@@ -411,26 +435,15 @@
   bool Accept(mojo::Message* message) override {
     DCHECK(thread_checker_.CalledOnValidThread());
 
-    if (mojo::PipeControlMessageHandler::IsPipeControlMessage(message)) {
-      if (!control_message_handler_.Accept(message))
-        RaiseError();
-      return true;
-    }
+    if (mojo::PipeControlMessageHandler::IsPipeControlMessage(message))
+      return control_message_handler_.Accept(message);
 
     mojo::InterfaceId id = message->interface_id();
     DCHECK(mojo::IsValidInterfaceId(id));
 
     base::AutoLock locker(lock_);
-    bool inserted = false;
-    Endpoint* endpoint = FindOrInsertEndpoint(id, &inserted);
-    if (inserted) {
-      MarkClosedAndMaybeRemove(endpoint);
-      if (!mojo::IsMasterInterfaceId(id))
-        control_message_proxy_.NotifyPeerEndpointClosed(id);
-      return true;
-    }
-
-    if (endpoint->closed())
+    Endpoint* endpoint = GetEndpointForDispatch(id);
+    if (!endpoint)
       return true;
 
     mojo::InterfaceEndpointClient* client = endpoint->client();
@@ -442,7 +455,6 @@
       // If the client is not yet bound, it must be bound by the time this task
       // runs or else it's programmer error.
       DCHECK(proxy_task_runner_);
-      CHECK(false);
       std::unique_ptr<mojo::Message> passed_message(new mojo::Message);
       message->MoveTo(passed_message.get());
       proxy_task_runner_->PostTask(
@@ -456,23 +468,56 @@
     // If it's happening, it's a bug.
     DCHECK(!message->has_flag(mojo::Message::kFlagIsSync));
 
-    bool result = false;
-    {
-      base::AutoUnlock unlocker(lock_);
-      result = client->HandleIncomingMessage(message);
-    }
-
-    if (!result)
-      RaiseError();
-
-    return true;
+    base::AutoUnlock unlocker(lock_);
+    return client->HandleIncomingMessage(message);
   }
 
+  void AcceptOnProxyThread(std::unique_ptr<mojo::Message> message) {
+    DCHECK(proxy_task_runner_->BelongsToCurrentThread());
+
+    mojo::InterfaceId id = message->interface_id();
+    DCHECK(mojo::IsValidInterfaceId(id) && !mojo::IsMasterInterfaceId(id));
+    // Hold |lock_| for lookup; it is released only while the client runs.
+    base::AutoLock locker(lock_);
+    Endpoint* endpoint = GetEndpointForDispatch(id);
+    if (!endpoint)
+      return;
+
+    mojo::InterfaceEndpointClient* client = endpoint->client();
+    if (!client)
+      return;
+
+    DCHECK(endpoint->task_runner()->BelongsToCurrentThread());
+
+    // TODO(rockot): Implement sync dispatch. For now, sync messages are
+    // unsupported here.
+    DCHECK(!message->has_flag(mojo::Message::kFlagIsSync));
+    // NOTE(review): RaiseError() below runs on the proxy thread; confirm safe.
+    bool result = false;
+    {
+      base::AutoUnlock unlocker(lock_);
+      result = client->HandleIncomingMessage(message.get());
+    }
+
+    if (!result)
+      RaiseError();
+  }
+
+  Endpoint* GetEndpointForDispatch(mojo::InterfaceId id) {
+    lock_.AssertAcquired();
+    bool inserted = false;
+    Endpoint* endpoint = FindOrInsertEndpoint(id, &inserted);
+    if (inserted) {  // Unknown |id|: close locally, notify peer if non-master.
+      MarkClosedAndMaybeRemove(endpoint);
+      if (!mojo::IsMasterInterfaceId(id))
+        NotifyPeerEndpointClosed(id);
+      return nullptr;
+    }
+    // An endpoint that is known but already closed is not dispatched to.
+    if (endpoint->closed())
+      return nullptr;
+
+    return endpoint;
+  }
 
   // mojo::PipeControlMessageHandlerDelegate:
@@ -561,6 +606,11 @@
     return controller_->associated_group();
   }
 
+  ChannelAssociatedGroupController* controller() {
+    DCHECK(controller_);
+    return controller_.get();  // Non-owning pointer.
+  }
+
   mojom::Bootstrap* operator->() {
     DCHECK(proxy_);
     return proxy_.get();
@@ -596,6 +646,11 @@
     return controller_->associated_group();
   }
 
+  ChannelAssociatedGroupController* controller() {
+    DCHECK(controller_);
+    return controller_.get();  // Non-owning pointer.
+  }
+
   void Bind(mojo::ScopedMessagePipeHandle handle) {
     DCHECK(!controller_);
     controller_ =
@@ -626,10 +681,16 @@
  private:
   // MojoBootstrap implementation.
   void Connect() override;
+
   mojo::AssociatedGroup* GetAssociatedGroup() override {
     return bootstrap_.associated_group();
   }
 
+  void SetProxyTaskRunner(
+      scoped_refptr<base::SingleThreadTaskRunner> task_runner) override {
+    bootstrap_.controller()->SetProxyTaskRunner(std::move(task_runner));
+  }
+
   void OnInitDone(int32_t peer_pid);
 
   BootstrapMasterProxy bootstrap_;
@@ -688,10 +749,16 @@
  private:
   // MojoBootstrap implementation.
   void Connect() override;
+
   mojo::AssociatedGroup* GetAssociatedGroup() override {
     return binding_.associated_group();
   }
 
+  void SetProxyTaskRunner(
+      scoped_refptr<base::SingleThreadTaskRunner> task_runner) override {
+    binding_.controller()->SetProxyTaskRunner(std::move(task_runner));
+  }
+
   // mojom::Bootstrap implementation.
   void Init(mojom::ChannelAssociatedRequest receive_channel,
             mojom::ChannelAssociatedPtrInfo send_channel,
diff --git a/ipc/ipc_mojo_bootstrap.h b/ipc/ipc_mojo_bootstrap.h
index 1d632d3..b9df4083 100644
--- a/ipc/ipc_mojo_bootstrap.h
+++ b/ipc/ipc_mojo_bootstrap.h
@@ -10,7 +10,9 @@
 #include <memory>
 
 #include "base/macros.h"
+#include "base/memory/ref_counted.h"
 #include "base/process/process_handle.h"
+#include "base/single_thread_task_runner.h"
 #include "build/build_config.h"
 #include "ipc/ipc.mojom.h"
 #include "ipc/ipc_channel.h"
@@ -54,6 +56,9 @@
 
   virtual mojo::AssociatedGroup* GetAssociatedGroup() = 0;
 
+  virtual void SetProxyTaskRunner(
+      scoped_refptr<base::SingleThreadTaskRunner> task_runner) = 0;
+
   // GetSelfPID returns our PID.
   base::ProcessId GetSelfPID() const;