[go: nahoru, domu]

Http cache: Return valid Content-Range headers for a byte range request.

BUG=12258
TEST=unittests

Review URL: http://codereview.chromium.org/140015

git-svn-id: svn://svn.chromium.org/chrome/trunk/src@18961 0039d316-1c4b-4281-b951-d872f2087c98
diff --git a/net/http/http_cache.cc b/net/http/http_cache.cc
index 44534807..18d928f 100644
--- a/net/http/http_cache.cc
+++ b/net/http/http_cache.cc
@@ -786,6 +786,8 @@
     NOTREACHED();
   }
 
+  partial_->UpdateFromStoredHeaders(response_.headers);
+
   return ContinuePartialCacheValidation();
 }
 
@@ -1074,6 +1076,10 @@
       // TODO(rvargas): Validate partial_content vs partial_ and mode_
       if (partial_content) {
         DCHECK(partial_.get());
+        if (!partial_->ResponseHeadersOK(new_response->headers)) {
+          // TODO(rvargas): Handle this error.
+          NOTREACHED();
+        }
       }
       // Are we expecting a response to a conditional query?
       if (mode_ == READ_WRITE) {
@@ -1103,6 +1109,10 @@
       }
 
       if (!(mode_ & READ)) {
+        // We change the value of Content-Length for partial content.
+        if (partial_content && partial_.get())
+          partial_->FixContentLength(new_response->headers);
+
         response_ = *new_response;
         WriteResponseInfoToEntry();
 
@@ -1123,6 +1133,10 @@
         }
         if (result >= 0 || result == net::ERR_IO_PENDING)
           return;
+      } else if (partial_.get()) {
+        // We are about to return the headers for a byte-range request to the
+        // user, so let's fix them.
+        partial_->FixResponseHeaders(response_.headers);
       }
     }
   } else if (IsCertificateError(result)) {
diff --git a/net/http/http_cache_unittest.cc b/net/http/http_cache_unittest.cc
index 2e0f94da..67b686a4 100644
--- a/net/http/http_cache_unittest.cc
+++ b/net/http/http_cache_unittest.cc
@@ -12,6 +12,7 @@
 #include "net/disk_cache/disk_cache.h"
 #include "net/http/http_byte_range.h"
 #include "net/http/http_request_info.h"
+#include "net/http/http_response_headers.h"
 #include "net/http/http_response_info.h"
 #include "net/http/http_transaction.h"
 #include "net/http/http_transaction_unittest.h"
@@ -351,7 +352,8 @@
 
 void RunTransactionTestWithRequest(net::HttpCache* cache,
                                    const MockTransaction& trans_info,
-                                   const MockHttpRequest& request) {
+                                   const MockHttpRequest& request,
+                                   std::string* response_headers) {
   TestCompletionCallback callback;
 
   // write to the cache
@@ -367,13 +369,23 @@
   const net::HttpResponseInfo* response = trans->GetResponseInfo();
   ASSERT_TRUE(response);
 
+  if (response_headers)
+    response->headers->GetNormalizedHeaders(response_headers);
+
   ReadAndVerifyTransaction(trans.get(), trans_info);
 }
 
 void RunTransactionTest(net::HttpCache* cache,
                         const MockTransaction& trans_info) {
   return RunTransactionTestWithRequest(
-      cache, trans_info, MockHttpRequest(trans_info));
+      cache, trans_info, MockHttpRequest(trans_info), NULL);
+}
+
+void RunTransactionTestWithResponse(net::HttpCache* cache,
+                                    const MockTransaction& trans_info,
+                                    std::string* response_headers) {
+  return RunTransactionTestWithRequest(
+      cache, trans_info, MockHttpRequest(trans_info), response_headers);
 }
 
 // This class provides a handler for kFastNoStoreGET_Transaction so that the
@@ -485,6 +497,35 @@
   0
 };
 
+// Returns true if the response headers (|response|) describe a 206 response
+// whose Content-Range is |start|-|end| and whose Content-Length matches it.
+bool Verify206Response(const std::string& response, int start, int end) {
+  // |response| is a header string; reassemble it into raw headers so that it
+  // can be parsed again.
+  std::string raw_headers(net::HttpUtil::AssembleRawHeaders(response.data(),
+                                                            response.size()));
+  scoped_refptr<net::HttpResponseHeaders> headers =
+      new net::HttpResponseHeaders(raw_headers);
+
+  if (206 != headers->response_code())
+    return false;
+
+  // |object_size| is required by the call but is not checked here.
+  int64 range_start, range_end, object_size;
+  if (!headers->GetContentRange(&range_start, &range_end, &object_size))
+    return false;
+
+  // The advertised body length has to cover exactly the inclusive range.
+  int64 content_length = headers->GetContentLength();
+  int length = end - start + 1;
+  if (content_length != length)
+    return false;
+
+  // Finally, the Content-Range itself has to be the expected one, with both
+  // ends inclusive.
+  return range_start == start && range_end == end;
+}
+
 }  // namespace
 
 
@@ -1103,7 +1144,7 @@
   request.upload_data->AppendBytes("hello", 5);
 
   // Populate the cache.
-  RunTransactionTestWithRequest(cache.http_cache(), transaction, request);
+  RunTransactionTestWithRequest(cache.http_cache(), transaction, request, NULL);
 
   EXPECT_EQ(1, cache.network_layer()->transaction_count());
   EXPECT_EQ(0, cache.disk_cache()->open_count());
@@ -1111,7 +1152,7 @@
 
   // Load from cache.
   request.load_flags |= net::LOAD_ONLY_FROM_CACHE;
-  RunTransactionTestWithRequest(cache.http_cache(), transaction, request);
+  RunTransactionTestWithRequest(cache.http_cache(), transaction, request, NULL);
 
   EXPECT_EQ(1, cache.network_layer()->transaction_count());
   EXPECT_EQ(1, cache.disk_cache()->open_count());
@@ -1155,16 +1196,22 @@
   // Test that we can cache range requests and fetch random blocks from the
   // cache and the network.
 
-  // Write to the cache (40-49).
-  RunTransactionTest(cache.http_cache(), kRangeGET_TransactionOK);
+  std::string headers;
 
+  // Write to the cache (40-49).
+  RunTransactionTestWithResponse(cache.http_cache(), kRangeGET_TransactionOK,
+                                 &headers);
+
+  EXPECT_TRUE(Verify206Response(headers, 40, 49));
   EXPECT_EQ(1, cache.network_layer()->transaction_count());
   EXPECT_EQ(0, cache.disk_cache()->open_count());
   EXPECT_EQ(1, cache.disk_cache()->create_count());
 
   // Read from the cache (40-49).
-  RunTransactionTest(cache.http_cache(), kRangeGET_TransactionOK);
+  RunTransactionTestWithResponse(cache.http_cache(), kRangeGET_TransactionOK,
+                                 &headers);
 
+  EXPECT_TRUE(Verify206Response(headers, 40, 49));
   EXPECT_EQ(2, cache.network_layer()->transaction_count());
   EXPECT_EQ(1, cache.disk_cache()->open_count());
   EXPECT_EQ(1, cache.disk_cache()->create_count());
@@ -1176,8 +1223,9 @@
   MockTransaction transaction(kRangeGET_TransactionOK);
   transaction.request_headers = "Range: bytes = 30-39\r\n";
   transaction.data = "rg: 30-39 ";
-  RunTransactionTest(cache.http_cache(), transaction);
+  RunTransactionTestWithResponse(cache.http_cache(), transaction, &headers);
 
+  EXPECT_TRUE(Verify206Response(headers, 30, 39));
   EXPECT_EQ(3, cache.network_layer()->transaction_count());
   EXPECT_EQ(2, cache.disk_cache()->open_count());
   EXPECT_EQ(1, cache.disk_cache()->create_count());
@@ -1188,8 +1236,9 @@
   // Write and read from the cache (20-59).
   transaction.request_headers = "Range: bytes = 20-59\r\n";
   transaction.data = "rg: 20-29 rg: 30-39 rg: 40-49 rg: 50-59 ";
-  RunTransactionTest(cache.http_cache(), transaction);
+  RunTransactionTestWithResponse(cache.http_cache(), transaction, &headers);
 
+  EXPECT_TRUE(Verify206Response(headers, 20, 59));
   EXPECT_EQ(6, cache.network_layer()->transaction_count());
   EXPECT_EQ(3, cache.disk_cache()->open_count());
   EXPECT_EQ(1, cache.disk_cache()->create_count());
diff --git a/net/http/http_response_headers.cc b/net/http/http_response_headers.cc
index f6ff4dff..f86eff8d 100644
--- a/net/http/http_response_headers.cc
+++ b/net/http/http_response_headers.cc
@@ -160,22 +160,22 @@
   DCHECK(new_headers.response_code() == 304 ||
          new_headers.response_code() == 206);
 
-  // copy up to the null byte.  this just copies the status line.
+  // Copy up to the null byte.  This just copies the status line.
   std::string new_raw_headers(raw_headers_.c_str());
   new_raw_headers.push_back('\0');
 
   HeaderSet updated_headers;
 
-  // NOTE: we write the new headers then the old headers for convenience.  the
+  // NOTE: we write the new headers then the old headers for convenience.  The
   // order should not matter.
 
-  // figure out which headers we want to take from new_headers:
+  // Figure out which headers we want to take from new_headers:
   for (size_t i = 0; i < new_headers.parsed_.size(); ++i) {
     const HeaderList& new_parsed = new_headers.parsed_;
 
     DCHECK(!new_parsed[i].is_continuation());
 
-    // locate the start of the next header
+    // Locate the start of the next header.
     size_t k = i;
     while (++k < new_parsed.size() && new_parsed[k].is_continuation());
     --k;
@@ -187,8 +187,8 @@
       StringToLowerASCII(&name);
       updated_headers.insert(name);
 
-      // preserve this header line in the merged result
-      // (including trailing '\0')
+      // Preserve this header line in the merged result, making sure there is
+      // a null after the value.
       new_raw_headers.append(name_begin, new_parsed[k].value_end);
       new_raw_headers.push_back('\0');
     }
@@ -196,19 +196,25 @@
     i = k;
   }
 
-  // now, build the new raw headers
+  // Now, build the new raw headers.
+  MergeWithHeaders(new_raw_headers, updated_headers);
+}
+
+void HttpResponseHeaders::MergeWithHeaders(const std::string& raw_headers,
+                                           const HeaderSet& headers_to_remove) {
+  std::string new_raw_headers(raw_headers);
   for (size_t i = 0; i < parsed_.size(); ++i) {
     DCHECK(!parsed_[i].is_continuation());
 
-    // locate the start of the next header
+    // Locate the start of the next header.
     size_t k = i;
     while (++k < parsed_.size() && parsed_[k].is_continuation());
     --k;
 
     std::string name(parsed_[i].name_begin, parsed_[i].name_end);
     StringToLowerASCII(&name);
-    if (updated_headers.find(name) == updated_headers.end()) {
-      // ok to preserve this header in the final result
+    if (headers_to_remove.find(name) == headers_to_remove.end()) {
+      // It's ok to preserve this header in the final result.
       new_raw_headers.append(parsed_[i].name_begin, parsed_[k].value_end);
       new_raw_headers.push_back('\0');
     }
@@ -217,7 +223,34 @@
   }
   new_raw_headers.push_back('\0');
 
-  // ok, make this object hold the new data
+  // Make this object hold the new data.
+  raw_headers_.clear();
+  parsed_.clear();
+  Parse(new_raw_headers);
+}
+
+void HttpResponseHeaders::RemoveHeader(const std::string& name) {
+  // MergeWithHeaders() compares lowercase names, so lowercase this one too.
+  std::string lowercase_name(name);
+  StringToLowerASCII(&lowercase_name);
+  HeaderSet to_remove;
+  to_remove.insert(lowercase_name);
+  // Start from the status line alone (everything up to the first null).
+  std::string status_line(raw_headers_.c_str());
+  status_line.push_back('\0');
+  MergeWithHeaders(status_line, to_remove);
+}
+
+void HttpResponseHeaders::AddHeader(const std::string& header) {
+  DCHECK_EQ('\0', raw_headers_[raw_headers_.size() - 2]);
+  DCHECK_EQ('\0', raw_headers_[raw_headers_.size() - 1]);
+  // Don't copy the last null.
+  std::string new_raw_headers(raw_headers_, 0, raw_headers_.size() - 1);
+  new_raw_headers.append(header);
+  new_raw_headers.push_back('\0');
+  new_raw_headers.push_back('\0');
+
+  // Make this object hold the new data.
   raw_headers_.clear();
   parsed_.clear();
   Parse(new_raw_headers);
diff --git a/net/http/http_response_headers.h b/net/http/http_response_headers.h
index 1989d076..0552904 100644
--- a/net/http/http_response_headers.h
+++ b/net/http/http_response_headers.h
@@ -61,6 +61,16 @@
   // Performs header merging as described in 13.5.3 of RFC 2616.
   void Update(const HttpResponseHeaders& new_headers);
 
+  // Removes all instances of a particular header.
+  void RemoveHeader(const std::string& name);
+
+  // Adds a particular header.  |header| has to be a single header without any
+  // EOL termination, just [<header-name>: <header-values>]
+  // If a header with the same name is already stored, the two headers are not
+  // merged together by this method; the one provided is simply put at the
+  // end of the list.
+  void AddHeader(const std::string& header);
+
   // Creates a normalized header string.  The output will be formatted exactly
   // like so:
   //     HTTP/<version> <status_code> <status_text>\n
@@ -219,6 +229,8 @@
  private:
   friend class base::RefCountedThreadSafe<HttpResponseHeaders>;
 
+  typedef base::hash_set<std::string> HeaderSet;
+
   HttpResponseHeaders() {}
   ~HttpResponseHeaders() {}
 
@@ -260,7 +272,12 @@
                    std::string::const_iterator value_begin,
                    std::string::const_iterator value_end);
 
-  typedef base::hash_set<std::string> HeaderSet;
+  // Replaces the current headers with the merged version of |raw_headers| and
+  // the current headers without the headers in |headers_to_remove|. Note that
+  // |headers_to_remove| are removed from the current headers (before the
+  // merge), not after the merge.
+  void MergeWithHeaders(const std::string& raw_headers,
+                        const HeaderSet& headers_to_remove);
 
   // Adds the values from any 'cache-control: no-cache="foo,bar"' headers.
   void AddNonCacheableHeaders(HeaderSet* header_names) const;
diff --git a/net/http/http_response_headers_unittest.cc b/net/http/http_response_headers_unittest.cc
index ec26eb9..ad03107 100644
--- a/net/http/http_response_headers_unittest.cc
+++ b/net/http/http_response_headers_unittest.cc
@@ -1414,3 +1414,93 @@
   // HTTP/1.0 200 OK.
   EXPECT_EQ(std::string("OK"), parsed->GetStatusText());
 }
+
+TEST(HttpResponseHeadersTest, AddHeader) {
+  const struct {
+    const char* orig_headers;
+    const char* new_header;
+    const char* expected_headers;
+  } tests[] = {
+    { "HTTP/1.1 200 OK\n"  // The new header ends up after the existing ones.
+      "connection: keep-alive\n"
+      "Cache-control: max-age=10000\n",
+
+      "Content-Length: 450",
+
+      "HTTP/1.1 200 OK\n"
+      "connection: keep-alive\n"
+      "Cache-control: max-age=10000\n"
+      "Content-Length: 450\n"
+    },
+    { "HTTP/1.1 200 OK\n"  // Trailing whitespace is not expected to survive.
+      "connection: keep-alive\n"
+      "Cache-control: max-age=10000    \n",
+
+      "Content-Length: 450  ",
+
+      "HTTP/1.1 200 OK\n"
+      "connection: keep-alive\n"
+      "Cache-control: max-age=10000\n"
+      "Content-Length: 450\n"
+    },
+  };
+
+  for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); ++i) {
+    string orig_headers(tests[i].orig_headers);
+    HeadersToRaw(&orig_headers);
+    scoped_refptr<HttpResponseHeaders> parsed =
+        new HttpResponseHeaders(orig_headers);
+
+    string new_header(tests[i].new_header);
+    parsed->AddHeader(new_header);
+
+    string resulting_headers;
+    parsed->GetNormalizedHeaders(&resulting_headers);  // Compared as text.
+    EXPECT_EQ(string(tests[i].expected_headers), resulting_headers);
+  }
+}
+
+TEST(HttpResponseHeadersTest, RemoveHeader) {
+  const struct {
+    const char* orig_headers;
+    const char* to_remove;
+    const char* expected_headers;
+  } tests[] = {
+    { "HTTP/1.1 200 OK\n"  // The header to remove is the last one.
+      "connection: keep-alive\n"
+      "Cache-control: max-age=10000\n"
+      "Content-Length: 450\n",
+
+      "Content-Length",
+
+      "HTTP/1.1 200 OK\n"
+      "connection: keep-alive\n"
+      "Cache-control: max-age=10000\n"
+    },
+    { "HTTP/1.1 200 OK\n"  // Mid-list header with spaces after its name.
+      "connection: keep-alive  \n"
+      "Content-Length  : 450  \n"
+      "Cache-control: max-age=10000\n",
+
+      "Content-Length",
+
+      "HTTP/1.1 200 OK\n"
+      "connection: keep-alive\n"
+      "Cache-control: max-age=10000\n"
+    },
+  };
+
+  for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); ++i) {
+    string orig_headers(tests[i].orig_headers);
+    HeadersToRaw(&orig_headers);
+    scoped_refptr<HttpResponseHeaders> parsed =
+        new HttpResponseHeaders(orig_headers);
+
+    string name(tests[i].to_remove);
+    parsed->RemoveHeader(name);
+
+    string resulting_headers;
+    parsed->GetNormalizedHeaders(&resulting_headers);  // Compared as text.
+    EXPECT_EQ(string(tests[i].expected_headers), resulting_headers);
+  }
+}
diff --git a/net/http/partial_data.cc b/net/http/partial_data.cc
index 44ba086..13cb377ec 100644
--- a/net/http/partial_data.cc
+++ b/net/http/partial_data.cc
@@ -8,8 +8,17 @@
 #include "base/string_util.h"
 #include "net/base/net_errors.h"
 #include "net/disk_cache/disk_cache.h"
+#include "net/http/http_response_headers.h"
 #include "net/http/http_util.h"
 
+namespace {
+
+// Names of the response headers that PartialData reads and rewrites.
+const char kLengthHeader[] = "Content-Length";
+const char kRangeHeader[] = "Content-Range";
+
+}  // namespace
+
 namespace net {
 
 bool PartialData::Init(const std::string& headers,
@@ -24,6 +33,7 @@
     return false;
 
   extra_headers_ = new_headers;
+  resource_size_ = 0;
 
   // TODO(rvargas): Handle requests without explicit start or end.
   DCHECK(byte_range_.HasFirstBytePosition());
@@ -92,6 +102,71 @@
   return final_range_;
 }
 
+void PartialData::UpdateFromStoredHeaders(const HttpResponseHeaders* headers) {
+  // The headers kept in the cache are expected to carry the size of the
+  // resource as their Content-Length.
+  std::string stored_length;
+  if (headers->GetNormalizedHeader(kLengthHeader, &stored_length) &&
+      StringToInt64(stored_length, &resource_size_)) {
+    return;
+  }
+
+  // Missing or unparsable length: flag it and fall back to zero.
+  NOTREACHED();
+  resource_size_ = 0;
+}
+
+bool PartialData::ResponseHeadersOK(const HttpResponseHeaders* headers) {
+  // A usable partial response needs a Content-Range with a positive total.
+  int64 start, end, total_length;
+  if (!headers->GetContentRange(&start, &end, &total_length) ||
+      total_length <= 0) {
+    return false;
+  }
+
+  if (resource_size_) {
+    // The total size has to agree with the one we already know about.
+    if (resource_size_ != total_length)
+      return false;
+  } else {
+    // First response: adopt the server's values for whatever is still unset.
+    resource_size_ = total_length;
+    if (!byte_range_.HasFirstBytePosition())
+      byte_range_.set_first_byte_position(start);
+    if (!byte_range_.HasLastBytePosition())
+      byte_range_.set_last_byte_position(end);
+  }
+
+  // The data must start at the current range position and must not run past
+  // the last byte of the range requested by the user.
+  return start == current_range_start_ &&
+         end <= byte_range_.last_byte_position();
+}
+
+// The user's range may have been completed with several requests.  Just assume
+// that everything went fine and report that we are returning exactly what was
+// requested.
+void PartialData::FixResponseHeaders(HttpResponseHeaders* headers) {
+  headers->RemoveHeader(kLengthHeader);
+  headers->RemoveHeader(kRangeHeader);
+
+  DCHECK(byte_range_.HasFirstBytePosition());
+  DCHECK(byte_range_.HasLastBytePosition());
+  const int64 first_byte = byte_range_.first_byte_position();
+  const int64 last_byte = byte_range_.last_byte_position();
+  headers->AddHeader(StringPrintf("%s: bytes %lld-%lld/%lld", kRangeHeader,
+                                  first_byte, last_byte, resource_size_));
+
+  // Both ends of the range are inclusive, hence the +1.
+  headers->AddHeader(StringPrintf("%s: %lld", kLengthHeader,
+                                  last_byte - first_byte + 1));
+}
+
+void PartialData::FixContentLength(HttpResponseHeaders* headers) {
+  headers->RemoveHeader(kLengthHeader);  // AddHeader() never replaces.
+  headers->AddHeader(StringPrintf("%s: %lld", kLengthHeader, resource_size_));
+}
+
 int PartialData::CacheRead(disk_cache::Entry* entry, IOBuffer* data,
                            int data_len, CompletionCallback* callback) {
   int read_len = std::min(data_len, cached_min_len_);
diff --git a/net/http/partial_data.h b/net/http/partial_data.h
index ced53abe..5dc1520 100644
--- a/net/http/partial_data.h
+++ b/net/http/partial_data.h
@@ -17,6 +17,7 @@
 
 namespace net {
 
+class HttpResponseHeaders;
 class IOBuffer;
 
 // This class provides support for dealing with range requests and the
@@ -59,6 +60,18 @@
   // user's request.
   bool IsLastRange() const;
 
+  // Extracts info from headers already stored in the cache.
+  void UpdateFromStoredHeaders(const HttpResponseHeaders* headers);
+
+  // Returns true if the response headers match what we expect, false otherwise.
+  bool ResponseHeadersOK(const HttpResponseHeaders* headers);
+
+  // Fixes the response headers to include the right content length and range.
+  void FixResponseHeaders(HttpResponseHeaders* headers);
+
+  // Fixes the content length that we want to store in the cache.
+  void FixContentLength(HttpResponseHeaders* headers);
+
   // Reads up to |data_len| bytes from the cache and stores them in the provided
   // buffer (|data|). Basically, this is just a wrapper around the API of the
   // cache that provides the right arguments for the current range. When the IO
@@ -85,6 +98,7 @@
 
   int64 current_range_start_;
   int64 cached_start_;
+  int64 resource_size_;
   int cached_min_len_;
   HttpByteRange byte_range_;  // The range requested by the user.
   std::string extra_headers_;  // The clean set of extra headers (no ranges).