[go: nahoru, domu]

Implement fetching public keys from network

The network fetcher fetches keys from helper servers and parses the
downloaded JSON using the data decoder service's JSON parser.

Integrate the network fetcher with the key fetcher so that keys are requested
from the network on first use and when the stored keys have expired. The
parsed keys are saved to storage. Also add logic to batch together concurrent
fetches for the same origin.

This CL also introduces the concept of a Keyset that includes the fetch
and expiry times.

Bug: 1217823
Change-Id: I4d6915b57f9e3e301f09ca4a13db6dcc49bbc197
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3045682
Reviewed-by: Alex Turner <alexmt@chromium.org>
Reviewed-by: Bret Sepulveda <bsep@chromium.org>
Reviewed-by: Robert Sesek <rsesek@chromium.org>
Reviewed-by: John Delaney <johnidel@chromium.org>
Commit-Queue: Nan Lin <linnan@chromium.org>
Cr-Commit-Position: refs/heads/main@{#924434}
diff --git a/content/browser/aggregation_service/aggregation_service_key_fetcher.cc b/content/browser/aggregation_service/aggregation_service_key_fetcher.cc
index 48ff379..d4b5ae33 100644
--- a/content/browser/aggregation_service/aggregation_service_key_fetcher.cc
+++ b/content/browser/aggregation_service/aggregation_service_key_fetcher.cc
@@ -4,10 +4,13 @@
 
 #include "content/browser/aggregation_service/aggregation_service_key_fetcher.h"
 
+#include <memory>
 #include <utility>
+#include <vector>
 
 #include "base/bind.h"
 #include "base/callback.h"
+#include "base/containers/circular_deque.h"
 #include "base/rand_util.h"
 #include "content/browser/aggregation_service/aggregatable_report_manager.h"
 #include "content/browser/aggregation_service/aggregation_service_key_storage.h"
@@ -18,8 +21,9 @@
 namespace content {
 
 AggregationServiceKeyFetcher::AggregationServiceKeyFetcher(
-    AggregatableReportManager* manager)
-    : manager_(manager) {}
+    AggregatableReportManager* manager,
+    std::unique_ptr<NetworkFetcher> network_fetcher)  // May be null.
+    : manager_(manager), network_fetcher_(std::move(network_fetcher)) {}
 
 AggregationServiceKeyFetcher::~AggregationServiceKeyFetcher() = default;
 
@@ -31,31 +35,104 @@
     return;
   }
 
+  base::circular_deque<FetchCallback>& pending_callbacks =
+      origin_callbacks_[origin];
+  bool in_progress = !pending_callbacks.empty();
+  pending_callbacks.push_back(std::move(callback));
+
+  // If there is already a fetch request in progress, just enqueue the
+  // callback and return.
+  if (in_progress)
+    return;
+
+  // First we check if we already have keys stored.
+  // TODO(crbug.com/1223488): Pass origin by value and move after C++17.
   manager_->GetKeyStorage()
       .AsyncCall(&AggregationServiceKeyStorage::GetPublicKeys)
       .WithArgs(origin)
       .Then(base::BindOnce(
           &AggregationServiceKeyFetcher::OnPublicKeysReceivedFromStorage,
-          weak_factory_.GetWeakPtr(), std::move(callback)));
+          weak_factory_.GetWeakPtr(), origin));
 }
 
 void AggregationServiceKeyFetcher::OnPublicKeysReceivedFromStorage(
-    FetchCallback callback,
-    PublicKeysForOrigin keys_for_origin) {
-  if (keys_for_origin.keys.empty()) {
-    // TODO(crbug.com/1217823): Fetch public keys from the network.
-
-    std::move(callback).Run(absl::nullopt,
-                            PublicKeyFetchStatus::kPublicKeyFetchFailed);
+    const url::Origin& origin,
+    std::vector<PublicKey> keys) {
+  if (keys.empty()) {
+    // No stored keys: fetch from the network; callbacks stay queued.
+    FetchPublicKeysFromNetwork(origin);
     return;
   }
 
-  // Each report should randomly select a key. This ensures that the set of
-  // reports a client sends are not a subset of the reports identified by any
-  // one key.
-  uint64_t key_index = base::RandGenerator(keys_for_origin.keys.size());
-  std::move(callback).Run(std::move(keys_for_origin.keys[key_index]),
-                          PublicKeyFetchStatus::kOk);
+  RunCallbacksForOrigin(origin, keys);  // Serve all queued callbacks.
+}
+
+void AggregationServiceKeyFetcher::FetchPublicKeysFromNetwork(
+    const url::Origin& origin) {
+  if (network_fetcher_) {
+    // base::Unretained is safe because `network_fetcher_` is owned by `this`
+    // and is deleted before `this`.
+    auto on_fetched = base::BindOnce(
+        &AggregationServiceKeyFetcher::OnPublicKeysReceivedFromNetwork,
+        base::Unretained(this), origin);
+    network_fetcher_->FetchPublicKeys(origin, std::move(on_fetched));
+    return;
+  }
+
+  // Network fetching is not enabled (no fetcher was provided), so report
+  // failure to every request queued for `origin`.
+  RunCallbacksForOrigin(origin, /*keys=*/{});
+}
+
+// Updates the stored keys for `origin`, then serves its queued callbacks.
+void AggregationServiceKeyFetcher::OnPublicKeysReceivedFromNetwork(
+    const url::Origin& origin,
+    absl::optional<PublicKeyset> keyset) {
+  if (!keyset.has_value() || keyset->expiry_time.is_null()) {
+    // `keyset` is absl::nullopt on error; a null `expiry_time` means a zero
+    // freshness lifetime. Either way, clear the stored keys for `origin`.
+    // Any fetched keys are still handed to the callbacks below.
+    manager_->GetKeyStorage()
+        .AsyncCall(&AggregationServiceKeyStorage::ClearPublicKeys)
+        .WithArgs(origin);
+  } else {
+    // Replace the keys stored for `origin` with the freshly fetched ones.
+    manager_->GetKeyStorage()
+        .AsyncCall(&AggregationServiceKeyStorage::SetPublicKeys)
+        .WithArgs(origin, keyset.value());
+  }
+  // Both operands are lvalues, so `keyset->keys` is bound, not copied.
+  const std::vector<PublicKey> no_keys;
+  RunCallbacksForOrigin(origin, keyset.has_value() ? keyset->keys : no_keys);
+}
+
+void AggregationServiceKeyFetcher::RunCallbacksForOrigin(
+    const url::Origin& origin,
+    const std::vector<PublicKey>& keys) {
+  auto it = origin_callbacks_.find(origin);
+  DCHECK(it != origin_callbacks_.end());
+
+  // Detach the queue from the map before running anything, so a callback
+  // that asks for `origin` again is handled as a new request.
+  base::circular_deque<FetchCallback> callbacks = std::move(it->second);
+  origin_callbacks_.erase(it);
+  DCHECK(!callbacks.empty());
+
+  if (keys.empty()) {
+    // No keys: fail every request. Don't refetch, to avoid an infinite loop.
+    for (FetchCallback& callback : callbacks) {
+      std::move(callback).Run(absl::nullopt,
+                              PublicKeyFetchStatus::kPublicKeyFetchFailed);
+    }
+    return;
+  }
+
+  // Every report draws its own random key, so the set of reports a client
+  // sends is not a subset of the reports identified by any one key.
+  for (FetchCallback& callback : callbacks) {
+    const uint64_t key_index = base::RandGenerator(keys.size());
+    std::move(callback).Run(keys[key_index], PublicKeyFetchStatus::kOk);
+  }
 }
 
 }  // namespace content