[go: nahoru, domu]

Implement the AggregatableReportAssembler

This class provides an interface for assembling an aggregatable report.
It is therefore responsible for requesting the appropriate public keys
(one per processing origin), calling into CreateFromRequestAndPublicKeys()
and returning the constructed AggregatableReport.

This class will be owned by the AggregatableReportManager instance, once
implemented.

Bug: 1217821
Change-Id: I910d8233c92e6ce68b446391234f8f0788c49f21
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3115352
Reviewed-by: John Delaney <johnidel@chromium.org>
Reviewed-by: Nan Lin <linnan@chromium.org>
Commit-Queue: Alex Turner <alexmt@chromium.org>
Cr-Commit-Position: refs/heads/main@{#928252}
diff --git a/content/browser/BUILD.gn b/content/browser/BUILD.gn
index 00d4d12..1c70b7d 100644
--- a/content/browser/BUILD.gn
+++ b/content/browser/BUILD.gn
@@ -356,6 +356,8 @@
     "after_startup_task_utils.h",
     "aggregation_service/aggregatable_report.cc",
     "aggregation_service/aggregatable_report.h",
+    "aggregation_service/aggregatable_report_assembler.cc",
+    "aggregation_service/aggregatable_report_assembler.h",
     "aggregation_service/aggregatable_report_manager.h",
     "aggregation_service/aggregatable_report_sender.cc",
     "aggregation_service/aggregatable_report_sender.h",
diff --git a/content/browser/aggregation_service/aggregatable_report_assembler.cc b/content/browser/aggregation_service/aggregatable_report_assembler.cc
new file mode 100644
index 0000000..9087936f
--- /dev/null
+++ b/content/browser/aggregation_service/aggregatable_report_assembler.cc
@@ -0,0 +1,173 @@
+// Copyright 2021 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "content/browser/aggregation_service/aggregatable_report_assembler.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "base/bind.h"
+#include "base/check.h"
+#include "base/check_op.h"
+#include "base/containers/contains.h"
+#include "base/memory/ptr_util.h"
+#include "base/ranges/algorithm.h"
+#include "base/time/default_clock.h"
+#include "content/browser/aggregation_service/aggregatable_report.h"
+#include "content/browser/aggregation_service/aggregatable_report_manager.h"
+#include "content/browser/aggregation_service/aggregation_service_key_fetcher.h"
+#include "content/browser/aggregation_service/aggregation_service_network_fetcher_impl.h"
+#include "content/browser/aggregation_service/public_key.h"
+#include "content/public/browser/storage_partition.h"
+#include "third_party/abseil-cpp/absl/types/optional.h"
+#include "url/origin.h"
+
+namespace content {
+
+AggregatableReportAssembler::AggregatableReportAssembler(
+    AggregatableReportManager* manager,
+    StoragePartition* storage_partition)
+    : AggregatableReportAssembler(
+          /*fetcher=*/std::make_unique<AggregationServiceKeyFetcher>(
+              manager,
+              std::make_unique<AggregationServiceNetworkFetcherImpl>(
+                  base::DefaultClock::GetInstance(), storage_partition)),
+          std::make_unique<AggregatableReport::Provider>()) {}
+
+AggregatableReportAssembler::AggregatableReportAssembler(
+    std::unique_ptr<AggregationServiceKeyFetcher> fetcher,
+    std::unique_ptr<AggregatableReport::Provider> report_provider)
+    : fetcher_(std::move(fetcher)),
+      report_provider_(std::move(report_provider)) {}
+
+AggregatableReportAssembler::~AggregatableReportAssembler() = default;
+
+AggregatableReportAssembler::PendingRequest::PendingRequest(
+    AggregatableReportRequest report_request,
+    AssemblyCallback callback,
+    size_t num_processing_origins)
+    : report_request(std::move(report_request)),
+      callback(std::move(callback)),
+      processing_origin_keys(num_processing_origins) {
+  DCHECK(this->callback);  // The parameter was moved-from; check the member.
+}
+
+AggregatableReportAssembler::PendingRequest::PendingRequest(
+    PendingRequest&& other) = default;
+
+AggregatableReportAssembler::PendingRequest&
+AggregatableReportAssembler::PendingRequest::operator=(
+    PendingRequest&& other) = default;
+
+AggregatableReportAssembler::PendingRequest::~PendingRequest() = default;
+
+// static
+std::unique_ptr<AggregatableReportAssembler>
+AggregatableReportAssembler::CreateForTesting(
+    std::unique_ptr<AggregationServiceKeyFetcher> fetcher,
+    std::unique_ptr<AggregatableReport::Provider> report_provider) {
+  // The constructor is private, so `std::make_unique()` cannot reach it.
+  return base::WrapUnique(new AggregatableReportAssembler(
+      std::move(fetcher), std::move(report_provider)));
+}
+
+void AggregatableReportAssembler::AssembleReport(
+    AggregatableReportRequest report_request,
+    AssemblyCallback callback) {
+  DCHECK_EQ(report_request.processing_origins().size(),
+            AggregatableReport::kNumberOfProcessingOrigins);
+  DCHECK(base::ranges::is_sorted(report_request.processing_origins()));
+
+  const AggregationServicePayloadContents& contents =
+      report_request.payload_contents();
+
+  // Currently, these should be the only possible enum values.
+  DCHECK_EQ(contents.operation,
+            AggregationServicePayloadContents::Operation::kCountValueHistogram);
+  DCHECK_EQ(contents.processing_type,
+            AggregationServicePayloadContents::ProcessingType::kTwoParty);
+
+  if (pending_requests_.size() >= kMaxSimultaneousRequests) {
+    std::move(callback).Run(absl::nullopt,
+                            AssemblyStatus::kTooManySimultaneousRequests);
+    return;
+  }
+
+  int64_t id = unique_id_counter_++;
+  DCHECK(!base::Contains(pending_requests_, id));
+
+  // Copied since fetch callbacks may invalidate `pending_requests_` entries.
+  const std::vector<url::Origin> processing_origins =
+      report_request.processing_origins();
+  pending_requests_.emplace(
+      id, PendingRequest(std::move(report_request), std::move(callback),
+                         /*num_processing_origins=*/
+                         AggregatableReport::kNumberOfProcessingOrigins));
+
+  for (size_t i = 0; i < AggregatableReport::kNumberOfProcessingOrigins; ++i) {
+    // `fetcher_` is owned by `this`, so `base::Unretained()` is safe.
+    fetcher_->GetPublicKey(
+        processing_origins[i],
+        base::BindOnce(&AggregatableReportAssembler::OnPublicKeyFetched,
+                       base::Unretained(this), /*report_id=*/id,
+                       /*processing_origin_index=*/i));
+  }
+}
+
+void AggregatableReportAssembler::OnPublicKeyFetched(
+    int64_t report_id,
+    size_t processing_origin_index,
+    absl::optional<PublicKey> key,
+    AggregationServiceKeyFetcher::PublicKeyFetchStatus status) {
+  DCHECK_EQ(key.has_value(),
+            status == AggregationServiceKeyFetcher::PublicKeyFetchStatus::kOk);
+
+  // A missing entry means an earlier fetch for this report already failed.
+  auto pending_request_it = pending_requests_.find(report_id);
+  if (pending_request_it == pending_requests_.end())
+    return;
+  PendingRequest& pending_request = pending_request_it->second;
+
+  // TODO(crbug.com/1254792): Consider implementing some retry logic.
+  if (!key.has_value()) {
+    // Erase first: the callback may re-enter and invalidate the iterator.
+    AssemblyCallback callback = std::move(pending_request.callback);
+    pending_requests_.erase(pending_request_it);
+    std::move(callback).Run(absl::nullopt,
+                            AssemblyStatus::kPublicKeyFetchFailed);
+    return;
+  }
+
+  ++pending_request.num_returned_key_fetches;
+  pending_request.processing_origin_keys[processing_origin_index] =
+      std::move(key);
+  if (pending_request.num_returned_key_fetches ==
+      AggregatableReport::kNumberOfProcessingOrigins) {
+    OnAllPublicKeysFetched(report_id, pending_request);
+  }
+}
+
+void AggregatableReportAssembler::OnAllPublicKeysFetched(
+    int64_t report_id,
+    PendingRequest& pending_request) {
+  std::vector<PublicKey> public_keys;
+  public_keys.reserve(pending_request.processing_origin_keys.size());
+  for (auto& key : pending_request.processing_origin_keys) {
+    DCHECK(key.has_value());
+    public_keys.push_back(std::move(key.value()));
+  }
+
+  absl::optional<AggregatableReport> assembled_report =
+      report_provider_->CreateFromRequestAndPublicKeys(
+          std::move(pending_request.report_request), std::move(public_keys));
+  AssemblyStatus assembly_status =
+      assembled_report ? AssemblyStatus::kOk : AssemblyStatus::kAssemblyFailed;
+  AssemblyCallback callback = std::move(pending_request.callback);
+  pending_requests_.erase(report_id);  // Invalidates `pending_request`.
+  // Run last, as the callback may re-enter and modify `pending_requests_`.
+  std::move(callback).Run(std::move(assembled_report), assembly_status);
+}
+
+}  // namespace content
diff --git a/content/browser/aggregation_service/aggregatable_report_assembler.h b/content/browser/aggregation_service/aggregatable_report_assembler.h
new file mode 100644
index 0000000..5481746
--- /dev/null
+++ b/content/browser/aggregation_service/aggregatable_report_assembler.h
@@ -0,0 +1,134 @@
+// Copyright 2021 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef CONTENT_BROWSER_AGGREGATION_SERVICE_AGGREGATABLE_REPORT_ASSEMBLER_H_
+#define CONTENT_BROWSER_AGGREGATION_SERVICE_AGGREGATABLE_REPORT_ASSEMBLER_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <array>
+#include <memory>
+#include <vector>
+
+#include "base/callback.h"
+#include "base/containers/flat_map.h"
+#include "content/browser/aggregation_service/aggregatable_report.h"
+#include "content/browser/aggregation_service/aggregation_service_key_fetcher.h"
+#include "content/browser/aggregation_service/public_key.h"
+#include "content/common/content_export.h"
+#include "third_party/abseil-cpp/absl/types/optional.h"
+#include "url/origin.h"
+
+namespace content {
+
+class AggregatableReportManager;
+class StoragePartition;
+
+// This class provides an interface for assembling an aggregatable report. It is
+// therefore responsible for taking a request, identifying and requesting the
+// appropriate public keys, and generating and returning the AggregatableReport.
+class CONTENT_EXPORT AggregatableReportAssembler {
+ public:
+  enum class AssemblyStatus {
+    kOk,
+
+    // The attempt to fetch a public key failed.
+    kPublicKeyFetchFailed,
+
+    // An internal error occurred while attempting to construct the report.
+    kAssemblyFailed,
+
+    // The limit on the number of simultaneous requests has been reached.
+    kTooManySimultaneousRequests,
+    kMaxValue = kTooManySimultaneousRequests,
+  };
+
+  using AssemblyCallback =
+      base::OnceCallback<void(absl::optional<AggregatableReport>,
+                              AssemblyStatus)>;
+
+  // While we shouldn't hit this limit in typical usage, we protect against
+  // the possibility of unbounded memory growth.
+  static constexpr size_t kMaxSimultaneousRequests = 1000;
+
+  explicit AggregatableReportAssembler(AggregatableReportManager* manager,
+                                       StoragePartition* storage_partition);
+  // Not copyable or movable.
+  AggregatableReportAssembler(const AggregatableReportAssembler& other) =
+      delete;
+  AggregatableReportAssembler& operator=(
+      const AggregatableReportAssembler& other) = delete;
+  virtual ~AggregatableReportAssembler();
+
+  static std::unique_ptr<AggregatableReportAssembler> CreateForTesting(
+      std::unique_ptr<AggregationServiceKeyFetcher> fetcher,
+      std::unique_ptr<AggregatableReport::Provider> report_provider);
+
+  // Fetches the necessary public keys and uses them to construct an
+  // AggregatableReport from `report_request`, then runs `callback` with the
+  // report or an error status. See AggregatableReport for more detail.
+  void AssembleReport(AggregatableReportRequest report_request,
+                      AssemblyCallback callback);
+
+ private:
+  // Represents a request to assemble a report that has not completed.
+  struct PendingRequest {
+    PendingRequest(AggregatableReportRequest report_request,
+                   AssemblyCallback callback,
+                   size_t num_processing_origins);
+    // Move-only.
+    PendingRequest(PendingRequest&& other);
+    PendingRequest& operator=(PendingRequest&& other);
+    ~PendingRequest();
+
+    AggregatableReportRequest report_request;
+    AssemblyCallback callback;
+
+    // How many key fetches for this request have returned successfully.
+    size_t num_returned_key_fetches = 0;
+
+    // The PublicKey returned for each key fetch request. Indices correspond to
+    // the ordering of `report_request.processing_origins()`. Each element is
+    // `absl::nullopt` until that key fetch succeeds; a failed fetch ends the
+    // whole request instead.
+    std::vector<absl::optional<PublicKey>> processing_origin_keys;
+  };
+
+  AggregatableReportAssembler(
+      std::unique_ptr<AggregationServiceKeyFetcher> fetcher,
+      std::unique_ptr<AggregatableReport::Provider> report_provider);
+
+  // Called when a result is returned from the key fetcher. Reports an error on
+  // a failed fetch; otherwise records the key and, once all fetches for the
+  // request have returned, calls into `OnAllPublicKeysFetched()`.
+  // `processing_origin_index` is an index into the corresponding
+  // AggregatableReportRequest's `processing_origins` vector, indicating which
+  // origin this fetch is for.
+  void OnPublicKeyFetched(
+      int64_t report_id,
+      size_t processing_origin_index,
+      absl::optional<PublicKey> key,
+      AggregationServiceKeyFetcher::PublicKeyFetchStatus status);
+
+  // Called when all results have been returned from the key fetcher. Builds
+  // the report via `AggregatableReport::Provider`, runs the request's callback
+  // with the result (`kAssemblyFailed` if construction fails) and removes it.
+  void OnAllPublicKeysFetched(int64_t report_id,
+                              PendingRequest& pending_request);
+
+  // Keyed by the unique id assigned to each request in `AssembleReport()`.
+  base::flat_map<int64_t, PendingRequest> pending_requests_;
+
+  // Used to generate unique ids for PendingRequests. These need to be unique
+  // per Assembler for tracking pending requests.
+  int64_t unique_id_counter_ = 0;
+
+  std::unique_ptr<AggregationServiceKeyFetcher> fetcher_;
+  std::unique_ptr<AggregatableReport::Provider> report_provider_;
+};
+
+}  // namespace content
+
+#endif  // CONTENT_BROWSER_AGGREGATION_SERVICE_AGGREGATABLE_REPORT_ASSEMBLER_H_
diff --git a/content/browser/aggregation_service/aggregatable_report_assembler_unittest.cc b/content/browser/aggregation_service/aggregatable_report_assembler_unittest.cc
new file mode 100644
index 0000000..fa0eb55
--- /dev/null
+++ b/content/browser/aggregation_service/aggregatable_report_assembler_unittest.cc
@@ -0,0 +1,375 @@
+// Copyright 2021 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "content/browser/aggregation_service/aggregatable_report_assembler.h"
+
+#include <stddef.h>
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "base/test/bind.h"
+#include "base/time/time.h"
+#include "content/browser/aggregation_service/aggregatable_report.h"
+#include "content/browser/aggregation_service/aggregation_service_key_fetcher.h"
+#include "content/browser/aggregation_service/aggregation_service_test_utils.h"
+#include "content/browser/aggregation_service/public_key.h"
+#include "testing/gtest/include/gtest/gtest.h"
+#include "third_party/abseil-cpp/absl/types/optional.h"
+#include "url/gurl.h"
+#include "url/origin.h"
+
+namespace content {
+
+class AggregatableReportAssemblerTest : public testing::Test {
+ public:
+  AggregatableReportAssemblerTest() = default;
+
+  void SetUp() override {
+    auto fetcher = std::make_unique<TestAggregationServiceKeyFetcher>();
+    auto report_provider = std::make_unique<TestAggregatableReportProvider>();
+
+    fetcher_ = fetcher.get();
+    report_provider_ = report_provider.get();
+    assembler_ = AggregatableReportAssembler::CreateForTesting(
+        std::move(fetcher), std::move(report_provider));
+
+    num_assembly_callbacks_run_ = 0;
+  }
+
+  void AssembleReport(AggregatableReportRequest request) {
+    assembler()->AssembleReport(
+        std::move(request),
+        base::BindLambdaForTesting(
+            [this](absl::optional<AggregatableReport> report,
+                   AggregatableReportAssembler::AssemblyStatus status) {
+              last_assembled_report_ = std::move(report);
+              last_assembled_status_ = status;
+
+              ++num_assembly_callbacks_run_;
+            }));
+  }
+
+  void ResetAssembler() { assembler_.reset(); }
+
+  AggregatableReportAssembler* assembler() { return assembler_.get(); }
+  TestAggregationServiceKeyFetcher* fetcher() { return fetcher_; }
+  TestAggregatableReportProvider* report_provider() { return report_provider_; }
+  int num_assembly_callbacks_run() const { return num_assembly_callbacks_run_; }
+
+  // Should only be called after the report callback has been run.
+  const absl::optional<AggregatableReport>& last_assembled_report() const {
+    EXPECT_GT(num_assembly_callbacks_run_, 0);
+    return last_assembled_report_;
+  }
+  const AggregatableReportAssembler::AssemblyStatus& last_assembled_status()
+      const {
+    EXPECT_GT(num_assembly_callbacks_run_, 0);
+    return last_assembled_status_;
+  }
+
+ private:
+  std::unique_ptr<AggregatableReportAssembler> assembler_;
+
+  // Owned by `assembler_`; these dangle after `ResetAssembler()`.
+  TestAggregationServiceKeyFetcher* fetcher_ = nullptr;
+  TestAggregatableReportProvider* report_provider_ = nullptr;
+
+  int num_assembly_callbacks_run_ = 0;
+
+  // The last arguments passed to the AssemblyCallback are saved here.
+  absl::optional<AggregatableReport> last_assembled_report_;
+  AggregatableReportAssembler::AssemblyStatus last_assembled_status_{};
+};
+
+TEST_F(AggregatableReportAssemblerTest, BothKeyFetchesFail_ErrorReturned) {
+  AggregatableReportRequest request =
+      aggregation_service::CreateExampleRequest();
+  const std::vector<url::Origin> origins = request.processing_origins();
+
+  AssembleReport(std::move(request));
+
+  // Fail the key fetch for every processing origin, in request order.
+  for (const url::Origin& origin : origins) {
+    fetcher()->TriggerPublicKeyResponse(
+        origin, /*key=*/absl::nullopt,
+        AggregationServiceKeyFetcher::PublicKeyFetchStatus::
+            kPublicKeyFetchFailed);
+  }
+
+  // Only the first failure is reported, and no report is ever built.
+  EXPECT_FALSE(fetcher()->HasPendingCallbacks());
+  EXPECT_EQ(num_assembly_callbacks_run(), 1);
+  EXPECT_EQ(report_provider()->num_calls(), 0);
+
+  EXPECT_FALSE(last_assembled_report().has_value());
+  EXPECT_EQ(last_assembled_status(),
+            AggregatableReportAssembler::AssemblyStatus::kPublicKeyFetchFailed);
+}
+
+TEST_F(AggregatableReportAssemblerTest, FirstKeyFetchFails_ErrorReturned) {
+  AggregatableReportRequest request =
+      aggregation_service::CreateExampleRequest();
+  const url::Origin first_origin = request.processing_origins()[0];
+  const url::Origin second_origin = request.processing_origins()[1];
+  AssembleReport(std::move(request));
+
+  // Fail the first fetch; the second origin's later success must be ignored.
+  fetcher()->TriggerPublicKeyResponse(
+      first_origin, /*key=*/absl::nullopt,
+      AggregationServiceKeyFetcher::PublicKeyFetchStatus::
+          kPublicKeyFetchFailed);
+  fetcher()->TriggerPublicKeyResponse(
+      second_origin, aggregation_service::GenerateKey().public_key,
+      AggregationServiceKeyFetcher::PublicKeyFetchStatus::kOk);
+
+  EXPECT_FALSE(fetcher()->HasPendingCallbacks());
+  EXPECT_EQ(num_assembly_callbacks_run(), 1);
+  EXPECT_EQ(report_provider()->num_calls(), 0);
+  EXPECT_FALSE(last_assembled_report().has_value());
+  EXPECT_EQ(last_assembled_status(),
+            AggregatableReportAssembler::AssemblyStatus::kPublicKeyFetchFailed);
+}
+
+TEST_F(AggregatableReportAssemblerTest, SecondKeyFetchFails_ErrorReturned) {
+  AggregatableReportRequest request =
+      aggregation_service::CreateExampleRequest();
+  const url::Origin first_origin = request.processing_origins()[0];
+  const url::Origin second_origin = request.processing_origins()[1];
+  AssembleReport(std::move(request));
+
+  // The first fetch succeeds, but the second one's failure aborts the report.
+  fetcher()->TriggerPublicKeyResponse(
+      first_origin, aggregation_service::GenerateKey().public_key,
+      AggregationServiceKeyFetcher::PublicKeyFetchStatus::kOk);
+  fetcher()->TriggerPublicKeyResponse(
+      second_origin, /*key=*/absl::nullopt,
+      AggregationServiceKeyFetcher::PublicKeyFetchStatus::
+          kPublicKeyFetchFailed);
+
+  EXPECT_FALSE(fetcher()->HasPendingCallbacks());
+  EXPECT_EQ(num_assembly_callbacks_run(), 1);
+  EXPECT_EQ(report_provider()->num_calls(), 0);
+  EXPECT_FALSE(last_assembled_report().has_value());
+  EXPECT_EQ(last_assembled_status(),
+            AggregatableReportAssembler::AssemblyStatus::kPublicKeyFetchFailed);
+}
+
+TEST_F(AggregatableReportAssemblerTest,
+       BothKeyFetchesSucceed_ValidReportReturned) {
+  AggregatableReportRequest request =
+      aggregation_service::CreateExampleRequest();
+
+  std::vector<url::Origin> origins = request.processing_origins();
+  std::vector<PublicKey> keys = {
+      aggregation_service::GenerateKey("id123").public_key,
+      aggregation_service::GenerateKey("456abc").public_key};
+
+  absl::optional<AggregatableReport> expected_report =
+      AggregatableReport::Provider().CreateFromRequestAndPublicKeys(
+          aggregation_service::CloneReportRequest(request), keys);
+  ASSERT_TRUE(expected_report.has_value());
+  report_provider()->set_report_to_return(std::move(*expected_report));
+
+  AssembleReport(aggregation_service::CloneReportRequest(request));
+
+  // Deliver each origin's key, in request order.
+  for (size_t i = 0; i < origins.size(); ++i) {
+    fetcher()->TriggerPublicKeyResponse(
+        origins[i], keys[i],
+        AggregationServiceKeyFetcher::PublicKeyFetchStatus::kOk);
+  }
+
+  EXPECT_FALSE(fetcher()->HasPendingCallbacks());
+  EXPECT_EQ(num_assembly_callbacks_run(), 1);
+
+  EXPECT_EQ(report_provider()->num_calls(), 1);
+  EXPECT_TRUE(aggregation_service::ReportRequestsEqual(
+      report_provider()->PreviousRequest(), request));
+  EXPECT_TRUE(aggregation_service::PublicKeysEqual(
+      report_provider()->PreviousPublicKeys(), keys));
+
+  EXPECT_TRUE(last_assembled_report().has_value());
+  EXPECT_EQ(last_assembled_status(),
+            AggregatableReportAssembler::AssemblyStatus::kOk);
+}
+
+TEST_F(AggregatableReportAssemblerTest,
+       KeyFetchesReturnInSwappedOrder_ValidReportReturned) {
+  AggregatableReportRequest request =
+      aggregation_service::CreateExampleRequest();
+
+  std::vector<url::Origin> origins = request.processing_origins();
+  std::vector<PublicKey> keys = {
+      aggregation_service::GenerateKey("id123").public_key,
+      aggregation_service::GenerateKey("456abc").public_key};
+
+  absl::optional<AggregatableReport> expected_report =
+      AggregatableReport::Provider().CreateFromRequestAndPublicKeys(
+          aggregation_service::CloneReportRequest(request), keys);
+  ASSERT_TRUE(expected_report.has_value());
+  report_provider()->set_report_to_return(std::move(*expected_report));
+
+  AssembleReport(aggregation_service::CloneReportRequest(request));
+
+  // Answer the second origin first; keys must still land at their own index.
+  fetcher()->TriggerPublicKeyResponse(
+      origins[1], keys[1],
+      AggregationServiceKeyFetcher::PublicKeyFetchStatus::kOk);
+  fetcher()->TriggerPublicKeyResponse(
+      origins[0], keys[0],
+      AggregationServiceKeyFetcher::PublicKeyFetchStatus::kOk);
+
+  EXPECT_FALSE(fetcher()->HasPendingCallbacks());
+  EXPECT_EQ(num_assembly_callbacks_run(), 1);
+
+  EXPECT_EQ(report_provider()->num_calls(), 1);
+  EXPECT_TRUE(aggregation_service::ReportRequestsEqual(
+      report_provider()->PreviousRequest(), request));
+  EXPECT_TRUE(aggregation_service::PublicKeysEqual(
+      report_provider()->PreviousPublicKeys(), keys));
+
+  EXPECT_TRUE(last_assembled_report().has_value());
+  EXPECT_EQ(last_assembled_status(),
+            AggregatableReportAssembler::AssemblyStatus::kOk);
+}
+
+TEST_F(AggregatableReportAssemblerTest,
+       BothProcessingOriginsAreIdentical_ValidReportReturned) {
+  AggregatableReportRequest example_request =
+      aggregation_service::CreateExampleRequest();
+
+  const url::Origin origin = example_request.processing_origins()[0];
+
+  // Rebuild the example request with both processing origins set to `origin`.
+  absl::optional<AggregatableReportRequest> request =
+      AggregatableReportRequest::Create({origin, origin},
+                                        example_request.payload_contents(),
+                                        example_request.shared_info());
+
+  ASSERT_TRUE(request.has_value());
+
+  PublicKey key = aggregation_service::GenerateKey("id123").public_key;
+
+  absl::optional<AggregatableReport> expected_report =
+      AggregatableReport::Provider().CreateFromRequestAndPublicKeys(
+          aggregation_service::CloneReportRequest(*request),
+          /*public_keys=*/{key, key});
+  ASSERT_TRUE(expected_report.has_value());
+  report_provider()->set_report_to_return(std::move(*expected_report));
+
+  AssembleReport(aggregation_service::CloneReportRequest(*request));
+
+  // One response for the shared origin answers both pending key fetches.
+  fetcher()->TriggerPublicKeyResponse(
+      origin, key, AggregationServiceKeyFetcher::PublicKeyFetchStatus::kOk);
+
+  EXPECT_FALSE(fetcher()->HasPendingCallbacks());
+  EXPECT_EQ(num_assembly_callbacks_run(), 1);
+
+  EXPECT_EQ(report_provider()->num_calls(), 1);
+  EXPECT_TRUE(aggregation_service::ReportRequestsEqual(
+      report_provider()->PreviousRequest(), *request));
+  EXPECT_TRUE(aggregation_service::PublicKeysEqual(
+      report_provider()->PreviousPublicKeys(), {key, key}));
+
+  EXPECT_TRUE(last_assembled_report().has_value());
+  EXPECT_EQ(last_assembled_status(),
+            AggregatableReportAssembler::AssemblyStatus::kOk);
+}
+
+TEST_F(AggregatableReportAssemblerTest,
+       AssemblerDeleted_PendingRequestsNotRun) {
+  AggregatableReportRequest request =
+      aggregation_service::CreateExampleRequest();
+
+  AssembleReport(std::move(request));
+  EXPECT_TRUE(fetcher()->HasPendingCallbacks());  // Key fetches outstanding.
+
+  ResetAssembler();
+  EXPECT_EQ(num_assembly_callbacks_run(), 0);
+}
+
+TEST_F(AggregatableReportAssemblerTest,
+       MultipleSimultaneousRequests_BothSucceed) {
+  AggregatableReportRequest request =
+      aggregation_service::CreateExampleRequest();
+
+  std::vector<url::Origin> origins = request.processing_origins();
+  std::vector<PublicKey> keys = {
+      aggregation_service::GenerateKey("id123").public_key,
+      aggregation_service::GenerateKey("456abc").public_key};
+
+  absl::optional<AggregatableReport> expected_report =
+      AggregatableReport::Provider().CreateFromRequestAndPublicKeys(
+          aggregation_service::CloneReportRequest(request), keys);
+  ASSERT_TRUE(expected_report.has_value());
+  report_provider()->set_report_to_return(std::move(*expected_report));
+
+  AssembleReport(aggregation_service::CloneReportRequest(request));
+  AssembleReport(aggregation_service::CloneReportRequest(request));
+
+  EXPECT_EQ(num_assembly_callbacks_run(), 0);
+
+  // Each response is delivered to the matching fetch of both requests.
+  fetcher()->TriggerPublicKeyResponse(
+      origins[0], keys[0],
+      AggregationServiceKeyFetcher::PublicKeyFetchStatus::kOk);
+  fetcher()->TriggerPublicKeyResponse(
+      origins[1], keys[1],
+      AggregationServiceKeyFetcher::PublicKeyFetchStatus::kOk);
+
+  EXPECT_FALSE(fetcher()->HasPendingCallbacks());
+  EXPECT_EQ(num_assembly_callbacks_run(), 2);
+  EXPECT_EQ(report_provider()->num_calls(), 2);
+
+  EXPECT_TRUE(aggregation_service::ReportRequestsEqual(
+      report_provider()->PreviousRequest(), request));
+  EXPECT_TRUE(aggregation_service::PublicKeysEqual(
+      report_provider()->PreviousPublicKeys(), keys));
+
+  EXPECT_TRUE(last_assembled_report().has_value());
+  EXPECT_EQ(last_assembled_status(),
+            AggregatableReportAssembler::AssemblyStatus::kOk);
+}
+
+TEST_F(AggregatableReportAssemblerTest,
+       TooManySimultaneousRequests_ErrorCausedForNewRequests) {
+  std::vector<PublicKey> public_keys = {
+      aggregation_service::GenerateKey("id123").public_key,
+      aggregation_service::GenerateKey("456abc").public_key};
+  absl::optional<AggregatableReport> report =
+      AggregatableReport::Provider().CreateFromRequestAndPublicKeys(
+          aggregation_service::CreateExampleRequest(), std::move(public_keys));
+  ASSERT_TRUE(report.has_value());
+  report_provider()->set_report_to_return(std::move(report.value()));
+
+  for (size_t i = 0; i < AggregatableReportAssembler::kMaxSimultaneousRequests;
+       ++i) {
+    AssembleReport(aggregation_service::CreateExampleRequest());
+  }
+
+  // All requests are still pending.
+  EXPECT_EQ(num_assembly_callbacks_run(), 0);
+
+  // Adding one request too many causes that new request to fail.
+  AssembleReport(aggregation_service::CreateExampleRequest());
+  EXPECT_EQ(num_assembly_callbacks_run(), 1);
+
+  EXPECT_FALSE(last_assembled_report().has_value());
+  EXPECT_EQ(last_assembled_status(),
+            AggregatableReportAssembler::AssemblyStatus::
+                kTooManySimultaneousRequests);
+
+  // The earlier requests were unaffected: they all complete once keys arrive.
+  fetcher()->TriggerPublicKeyResponseForAllOrigins(
+      aggregation_service::GenerateKey("id123").public_key,
+      AggregationServiceKeyFetcher::PublicKeyFetchStatus::kOk);
+
+  EXPECT_EQ(static_cast<size_t>(num_assembly_callbacks_run()),
+            AggregatableReportAssembler::kMaxSimultaneousRequests + 1);
+}
+
+}  // namespace content
diff --git a/content/browser/aggregation_service/aggregation_service_key_fetcher.h b/content/browser/aggregation_service/aggregation_service_key_fetcher.h
index 2f8538d..acd48ee 100644
--- a/content/browser/aggregation_service/aggregation_service_key_fetcher.h
+++ b/content/browser/aggregation_service/aggregation_service_key_fetcher.h
@@ -60,7 +60,7 @@
       delete;
   AggregationServiceKeyFetcher& operator=(
       const AggregationServiceKeyFetcher& other) = delete;
-  ~AggregationServiceKeyFetcher();
+  virtual ~AggregationServiceKeyFetcher();
 
   // Gets a currently valid public key for `origin` and triggers the `callback`
   // once completed.
@@ -75,8 +75,8 @@
   // available. At encryption time, the fetcher will (uniformly at random) pick
   // one of the public keys to use. This selection should be made independently
   // between reports so that the key choice cannot be used to partition reports
-  // into separate groups of users.
-  void GetPublicKey(const url::Origin& origin, FetchCallback callback);
+  // into separate groups of users. Virtual for mocking in tests.
+  virtual void GetPublicKey(const url::Origin& origin, FetchCallback callback);
 
  private:
   // Called when public keys are received from the storage.
diff --git a/content/browser/aggregation_service/aggregation_service_test_utils.cc b/content/browser/aggregation_service/aggregation_service_test_utils.cc
index 00c0475..608d328 100644
--- a/content/browser/aggregation_service/aggregation_service_test_utils.cc
+++ b/content/browser/aggregation_service/aggregation_service_test_utils.cc
@@ -7,10 +7,12 @@
 #include <stddef.h>
 #include <stdint.h>
 
+#include <ostream>
 #include <string>
 #include <tuple>
 #include <vector>
 
+#include "base/containers/contains.h"
 #include "base/task/thread_pool.h"
 #include "base/threading/sequence_bound.h"
 #include "base/time/time.h"
@@ -48,6 +50,96 @@
   return testing::AssertionSuccess();
 }
 
+using AggregationServicePayload = AggregatableReport::AggregationServicePayload;
+
+testing::AssertionResult AggregatableReportsEqual(
+    const AggregatableReport& expected,
+    const AggregatableReport& actual) {
+  const auto& want_payloads = expected.payloads();
+  const auto& got_payloads = actual.payloads();
+
+  // Payloads must line up one-to-one before they can be compared pairwise.
+  if (want_payloads.size() != got_payloads.size()) {
+    return testing::AssertionFailure()
+           << "Expected payloads size " << want_payloads.size()
+           << ", actual: " << got_payloads.size();
+  }
+
+  // Walk the payloads in order; the first differing field is reported.
+  size_t index = 0;
+  while (index < want_payloads.size()) {
+    const AggregationServicePayload& want = want_payloads[index];
+    const AggregationServicePayload& got = got_payloads[index];
+
+    if (want.origin != got.origin) {
+      return testing::AssertionFailure()
+             << "Expected origin " << want.origin << " at payload index "
+             << index << ", actual: " << got.origin;
+    }
+
+    // The payload values themselves are not printed, only their index.
+    if (want.payload != got.payload) {
+      return testing::AssertionFailure()
+             << "Expected payloads at payload index " << index << " to match";
+    }
+
+    if (want.key_id != got.key_id) {
+      return testing::AssertionFailure()
+             << "Expected key_id " << want.key_id << " at payload index "
+             << index << ", actual: " << got.key_id;
+    }
+
+    ++index;
+  }
+
+  // Every payload matched, so the result hinges on the shared info.
+  return SharedInfoEqual(expected.shared_info(), actual.shared_info());
+}
+
+testing::AssertionResult ReportRequestsEqual(
+    const AggregatableReportRequest& expected,
+    const AggregatableReportRequest& actual) {
+  const auto& want_origins = expected.processing_origins();
+  const auto& got_origins = actual.processing_origins();
+
+  if (want_origins.size() != got_origins.size()) {
+    return testing::AssertionFailure()
+           << "Expected processing_origins size " << want_origins.size()
+           << ", actual: " << got_origins.size();
+  }
+
+  // Origins are compared position by position, so ordering matters.
+  for (size_t idx = 0; idx != want_origins.size(); ++idx) {
+    if (want_origins[idx] == got_origins[idx])
+      continue;
+    return testing::AssertionFailure()
+           << "Expected processing_origins()[" << idx << "] to be "
+           << want_origins[idx] << ", actual: " << got_origins[idx];
+  }
+
+  // Then the payload contents, and finally the shared info.
+  testing::AssertionResult contents_result = PayloadContentsEqual(
+      expected.payload_contents(), actual.payload_contents());
+  if (!contents_result)
+    return contents_result;
+
+  return SharedInfoEqual(expected.shared_info(), actual.shared_info());
+}
+
+testing::AssertionResult PayloadContentsEqual(
+    const AggregationServicePayloadContents& expected,
+    const AggregationServicePayloadContents& actual) {
+  // Single-exit form: fields are checked in a fixed order (operation, bucket,
+  // value, processing_type) and only the first mismatch is reported.
+  testing::AssertionResult result = testing::AssertionSuccess();
+  if (expected.operation != actual.operation) {
+    result = testing::AssertionFailure()
+             << "Expected operation " << expected.operation
+             << ", actual: " << actual.operation;
+  } else if (expected.bucket != actual.bucket) {
+    result = testing::AssertionFailure()
+             << "Expected bucket " << expected.bucket
+             << ", actual: " << actual.bucket;
+  } else if (expected.value != actual.value) {
+    result = testing::AssertionFailure()
+             << "Expected value " << expected.value
+             << ", actual: " << actual.value;
+  } else if (expected.processing_type != actual.processing_type) {
+    result = testing::AssertionFailure()
+             << "Expected processing_type " << expected.processing_type
+             << ", actual: " << actual.processing_type;
+  }
+  return result;
+}
+
 testing::AssertionResult SharedInfoEqual(
     const AggregatableReportSharedInfo& expected,
     const AggregatableReportSharedInfo& actual) {
@@ -82,6 +174,23 @@
       .value();
 }
 
+AggregatableReportRequest CloneReportRequest(
+    const AggregatableReportRequest& request) {
+  // Rebuild the request through its public factory from the original's
+  // fields. value() is taken unconditionally, so Create() must succeed here.
+  auto rebuilt = AggregatableReportRequest::Create(
+      request.processing_origins(), request.payload_contents(),
+      request.shared_info());
+  return std::move(rebuilt).value();
+}
+
+AggregatableReport CloneAggregatableReport(const AggregatableReport& report) {
+  // Copy every payload field by field into a fresh vector. The final size is
+  // known up front, so reserve once instead of growing repeatedly.
+  std::vector<AggregationServicePayload> payloads;
+  payloads.reserve(report.payloads().size());
+  for (const AggregationServicePayload& payload : report.payloads()) {
+    payloads.emplace_back(payload.origin, payload.payload, payload.key_id);
+  }
+
+  return AggregatableReport(std::move(payloads), report.shared_info());
+}
+
 TestHpkeKey GenerateKey(std::string key_id) {
   bssl::ScopedEVP_HPKE_KEY key;
   EXPECT_TRUE(EVP_HPKE_KEY_generate(key.get(), EVP_hpke_x25519_hkdf_sha256()));
@@ -116,4 +225,81 @@
   return storage_;
 }
 
+// The base class receives null dependencies. NOTE(review): this relies on no
+// base-class code path dereferencing them; GetPublicKey() is overridden below,
+// but confirm for any other base methods tests may reach.
+TestAggregationServiceKeyFetcher::TestAggregationServiceKeyFetcher()
+    : AggregationServiceKeyFetcher(/*manager=*/nullptr,
+                                   /*network_fetcher=*/nullptr) {}
+
+TestAggregationServiceKeyFetcher::~TestAggregationServiceKeyFetcher() = default;
+
+// Performs no fetch: the callback is parked under `origin` until a test
+// resolves it with TriggerPublicKeyResponse().
+void TestAggregationServiceKeyFetcher::GetPublicKey(const url::Origin& origin,
+                                                    FetchCallback callback) {
+  std::vector<FetchCallback>& pending = callbacks_[origin];
+  pending.push_back(std::move(callback));
+}
+
+// Runs every pending GetPublicKey() callback for `origin` with `key` and
+// `status`. Records a fatal test failure (and returns) if there is no pending
+// request for `origin`, or if `key` is present without status kOk or vice
+// versa.
+void TestAggregationServiceKeyFetcher::TriggerPublicKeyResponse(
+    const url::Origin& origin,
+    absl::optional<PublicKey> key,
+    PublicKeyFetchStatus status) {
+  // A single lookup serves both the existence check and the removal below.
+  auto it = callbacks_.find(origin);
+  ASSERT_TRUE(it != callbacks_.end())
+      << "No corresponding GetPublicKey call for origin " << origin;
+  ASSERT_EQ(key.has_value(), status == PublicKeyFetchStatus::kOk)
+      << "Key must be returned if and only if status is kOk";
+
+  // Detach all pending callbacks for `origin` before running any of them, so
+  // a callback that issues a new GetPublicKey() for the same origin is queued
+  // afresh instead of being run (or dropped) by this call.
+  std::vector<FetchCallback> callbacks = std::move(it->second);
+  callbacks_.erase(it);
+  for (FetchCallback& callback : callbacks) {
+    std::move(callback).Run(key, status);
+  }
+}
+
+// Triggers `key`/`status` for every origin that has a pending request at the
+// time of the call.
+void TestAggregationServiceKeyFetcher::TriggerPublicKeyResponseForAllOrigins(
+    absl::optional<PublicKey> key,
+    PublicKeyFetchStatus status) {
+  // Snapshot the origins first: TriggerPublicKeyResponse() erases entries from
+  // `callbacks_`, so the map cannot be iterated while responses are delivered.
+  // Requests that callbacks issue for already-handled or new origins are left
+  // pending.
+  std::vector<url::Origin> origins;
+  origins.reserve(callbacks_.size());
+  for (const auto& origin_and_callbacks : callbacks_)
+    origins.push_back(origin_and_callbacks.first);
+
+  for (const url::Origin& origin : origins)
+    TriggerPublicKeyResponse(origin, key, status);
+}
+
+// Returns whether any GetPublicKey() request is still awaiting a response.
+bool TestAggregationServiceKeyFetcher::HasPendingCallbacks() {
+  return !callbacks_.empty();
+}
+
+TestAggregatableReportProvider::TestAggregatableReportProvider() = default;
+TestAggregatableReportProvider::~TestAggregatableReportProvider() = default;
+
+// Records the call and its arguments, then returns a copy of the report set
+// via set_report_to_return().
+absl::optional<AggregatableReport>
+TestAggregatableReportProvider::CreateFromRequestAndPublicKeys(
+    AggregatableReportRequest report_request,
+    std::vector<PublicKey> public_keys) const {
+  ++num_calls_;
+  // Both arguments are owned by value, so keep them by moving rather than
+  // cloning/copying.
+  previous_request_.emplace(std::move(report_request));
+  previous_public_keys_ = std::move(public_keys);
+
+  // A report must be configured before the provider is used. Flag misuse as a
+  // test failure and return nullopt instead of calling value() on an empty
+  // optional, which would abort the whole test binary.
+  EXPECT_TRUE(report_to_return_.has_value());
+  if (!report_to_return_.has_value())
+    return absl::nullopt;
+  return aggregation_service::CloneAggregatableReport(
+      report_to_return_.value());
+}
+
+// Test-only pretty printer for failure messages.
+std::ostream& operator<<(
+    std::ostream& out,
+    const AggregationServicePayloadContents::Operation& operation) {
+  switch (operation) {
+    case AggregationServicePayloadContents::Operation::kCountValueHistogram:
+      return out << "kCountValueHistogram";
+  }
+  // Reached only for a value outside the enumerators handled above. Without
+  // this return, control would fall off the end of a non-void function
+  // (undefined behavior, and a GCC -Wreturn-type warning).
+  return out << "<unknown Operation " << static_cast<int>(operation) << ">";
+}
+
+// Test-only pretty printer for failure messages.
+std::ostream& operator<<(
+    std::ostream& out,
+    const AggregationServicePayloadContents::ProcessingType& processing_type) {
+  switch (processing_type) {
+    case AggregationServicePayloadContents::ProcessingType::kTwoParty:
+      return out << "kTwoParty";
+  }
+  // Reached only for a value outside the enumerators handled above. Without
+  // this return, control would fall off the end of a non-void function
+  // (undefined behavior, and a GCC -Wreturn-type warning).
+  return out << "<unknown ProcessingType " << static_cast<int>(processing_type)
+             << ">";
+}
+
 }  // namespace content
diff --git a/content/browser/aggregation_service/aggregation_service_test_utils.h b/content/browser/aggregation_service/aggregation_service_test_utils.h
index ef8d54eb6..1baacc2 100644
--- a/content/browser/aggregation_service/aggregation_service_test_utils.h
+++ b/content/browser/aggregation_service/aggregation_service_test_utils.h
@@ -7,16 +7,21 @@
 
 #include <stdint.h>
 
+#include <map>
+#include <ostream>
 #include <string>
 #include <vector>
 
 #include "base/threading/sequence_bound.h"
 #include "content/browser/aggregation_service/aggregatable_report.h"
 #include "content/browser/aggregation_service/aggregatable_report_manager.h"
+#include "content/browser/aggregation_service/aggregation_service_key_fetcher.h"
 #include "content/browser/aggregation_service/aggregation_service_key_storage.h"
 #include "content/browser/aggregation_service/public_key.h"
 #include "testing/gtest/include/gtest/gtest.h"
+#include "third_party/abseil-cpp/absl/types/optional.h"
 #include "third_party/boringssl/src/include/openssl/hpke.h"
+#include "url/origin.h"
 
 namespace base {
 class Clock;
@@ -36,12 +41,25 @@
 
 testing::AssertionResult PublicKeysEqual(const std::vector<PublicKey>& expected,
                                          const std::vector<PublicKey>& actual);
+// Compares payloads pairwise, in order (origin, payload, key_id), then the
+// shared info. Reports the first mismatch found.
+testing::AssertionResult AggregatableReportsEqual(
+    const AggregatableReport& expected,
+    const AggregatableReport& actual);
+// Compares processing origins (order-sensitive), then payload contents, then
+// shared info. Reports the first mismatch found.
+testing::AssertionResult ReportRequestsEqual(
+    const AggregatableReportRequest& expected,
+    const AggregatableReportRequest& actual);
+// Compares operation, bucket, value and processing_type, in that order.
+testing::AssertionResult PayloadContentsEqual(
+    const AggregationServicePayloadContents& expected,
+    const AggregationServicePayloadContents& actual);
 testing::AssertionResult SharedInfoEqual(
     const AggregatableReportSharedInfo& expected,
     const AggregatableReportSharedInfo& actual);
 
 AggregatableReportRequest CreateExampleRequest();
 
+// Returns a copy of `request`, rebuilt via AggregatableReportRequest::Create()
+// from its fields; the factory result's value() is taken, so it must succeed.
+AggregatableReportRequest CloneReportRequest(
+    const AggregatableReportRequest& request);
+// Returns a copy of `report` with each payload's origin, payload and key_id
+// copied, and the same shared info.
+AggregatableReport CloneAggregatableReport(const AggregatableReport& report);
+
 // Generates a public-private key pair for HPKE and also constructs a PublicKey
 // object for use in assembler methods.
 TestHpkeKey GenerateKey(std::string key_id = "example_id");
@@ -69,6 +87,79 @@
   base::SequenceBound<content::AggregationServiceKeyStorage> storage_;
 };
 
+// Fake key fetcher for tests. GetPublicKey() performs no fetch; it stores the
+// callback so the test can later resolve it with TriggerPublicKeyResponse*().
+class TestAggregationServiceKeyFetcher : public AggregationServiceKeyFetcher {
+ public:
+  TestAggregationServiceKeyFetcher();
+  ~TestAggregationServiceKeyFetcher() override;
+
+  // AggregationServiceKeyFetcher:
+  void GetPublicKey(const url::Origin& origin, FetchCallback callback) override;
+
+  // Runs every pending callback for `origin` with `key` and `status`. Records a
+  // fatal test failure (gtest ASSERT, not an exception) if no such fetches
+  // exist, or if `key` is present iff `status` is not kOk.
+  void TriggerPublicKeyResponse(const url::Origin& origin,
+                                absl::optional<PublicKey> key,
+                                PublicKeyFetchStatus status);
+
+  // Calls TriggerPublicKeyResponse() for every origin that has pending fetches
+  // at the time of the call.
+  void TriggerPublicKeyResponseForAllOrigins(absl::optional<PublicKey> key,
+                                             PublicKeyFetchStatus status);
+
+  // Returns whether any GetPublicKey() request is still awaiting a response.
+  bool HasPendingCallbacks();
+
+ private:
+  // Pending GetPublicKey() callbacks, grouped by requested origin.
+  std::map<url::Origin, std::vector<FetchCallback>> callbacks_;
+};
+
+// A simple class for mocking CreateFromRequestAndPublicKeys(). Each call
+// records its arguments and returns a copy of the report configured with
+// set_report_to_return().
+class TestAggregatableReportProvider : public AggregatableReport::Provider {
+ public:
+  TestAggregatableReportProvider();
+  ~TestAggregatableReportProvider() override;
+
+  // AggregatableReport::Provider:
+  absl::optional<AggregatableReport> CreateFromRequestAndPublicKeys(
+      AggregatableReportRequest report_request,
+      std::vector<PublicKey> public_keys) const override;
+
+  // Number of times CreateFromRequestAndPublicKeys() has been called.
+  int num_calls() const { return num_calls_; }
+
+  // Arguments of the most recent CreateFromRequestAndPublicKeys() call. Both
+  // record a test failure if no call has been made yet; PreviousRequest() then
+  // additionally calls value() on an empty optional.
+  const AggregatableReportRequest& PreviousRequest() const {
+    EXPECT_TRUE(previous_request_.has_value());
+    return previous_request_.value();
+  }
+  const std::vector<PublicKey>& PreviousPublicKeys() const {
+    // `previous_request_` is set on every call, so it doubles as the
+    // "has been called" signal here.
+    EXPECT_TRUE(previous_request_.has_value());
+    return previous_public_keys_;
+  }
+
+  // Sets the report whose copy CreateFromRequestAndPublicKeys() returns. It is
+  // expected to hold a value before that method is called.
+  void set_report_to_return(
+      absl::optional<AggregatableReport> report_to_return) {
+    report_to_return_ = std::move(report_to_return);
+  }
+
+ private:
+  // Report copied into the result of CreateFromRequestAndPublicKeys().
+  absl::optional<AggregatableReport> report_to_return_;
+
+  // The following are mutable to allow `CreateFromRequestAndPublicKeys()` to be
+  // const.
+
+  // Number of times `CreateFromRequestAndPublicKeys()` is called.
+  mutable int num_calls_ = 0;
+
+  // `absl::nullopt` iff `num_calls_` is 0.
+  mutable absl::optional<AggregatableReportRequest> previous_request_;
+
+  // Empty if `num_calls_` is 0.
+  mutable std::vector<PublicKey> previous_public_keys_;
+};
+
+// Only used for logging in tests: print the enumerator name, e.g. in
+// PayloadContentsEqual() failure messages.
+std::ostream& operator<<(
+    std::ostream& out,
+    const AggregationServicePayloadContents::Operation& operation);
+std::ostream& operator<<(
+    std::ostream& out,
+    const AggregationServicePayloadContents::ProcessingType& processing_type);
+
 }  // namespace content
 
 #endif  // CONTENT_BROWSER_AGGREGATION_SERVICE_AGGREGATION_SERVICE_TEST_UTILS_H_
diff --git a/content/test/BUILD.gn b/content/test/BUILD.gn
index 6be7496..3adfbd7 100644
--- a/content/test/BUILD.gn
+++ b/content/test/BUILD.gn
@@ -1880,6 +1880,7 @@
     "../browser/accessibility/browser_accessibility_unittest.cc",
     "../browser/accessibility/one_shot_accessibility_tree_search_unittest.cc",
     "../browser/accessibility/touch_passthrough_manager_unittest.cc",
+    "../browser/aggregation_service/aggregatable_report_assembler_unittest.cc",
     "../browser/aggregation_service/aggregatable_report_sender_unittest.cc",
     "../browser/aggregation_service/aggregatable_report_unittest.cc",
     "../browser/aggregation_service/aggregation_service_key_fetcher_unittest.cc",