[go: nahoru, domu]

Skip to content

Commit

Permalink
[SavedModel Fingerprinting] Strip UID from function names in the node…
Browse files Browse the repository at this point in the history
…s of the graphdef during canonicalization.

RFC: tensorflow/community#415
PiperOrigin-RevId: 466798923
  • Loading branch information
Monica Song authored and tensorflower-gardener committed Aug 10, 2022
1 parent 091e1e9 commit c528f87
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 25 deletions.
41 changes: 23 additions & 18 deletions tensorflow/cc/saved_model/fingerprinting.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,34 @@ namespace tensorflow::fingerprinting {

namespace {

// This function mutates the GraphDef, changing the names and config_proto's
// Returns the suffix UID of `function_name`: the integer formed by the text
// that follows the last '_' in the name (or by the whole name when it contains
// no '_').
//
// Returns an InvalidArgument error when that trailing segment cannot be parsed
// as a 32-bit integer (this includes an empty segment, e.g. a name that ends
// in '_').
//
// NOTE(review): the exact set of strings accepted here is whatever
// `strings::safe_strto32` accepts; if it tolerates signs, surrounding
// whitespace or leading zeros, the returned value may not round-trip to the
// literal suffix text -- confirm before relying on that in callers.
StatusOr<int> GetSuffixUID(absl::string_view function_name) {
  // Only the segment after the final '_' matters, so locate it directly
  // instead of splitting the whole name into a vector of string copies.
  const auto last_underscore = function_name.rfind('_');
  const absl::string_view suffix =
      last_underscore == absl::string_view::npos
          ? function_name
          : function_name.substr(last_underscore + 1);

  int uid;
  if (!strings::safe_strto32(suffix, &uid)) {
    return errors::InvalidArgument(absl::StrCat(
        "Function name: `", function_name, "` does not end in an integer."));
  }
  return uid;
}

// This function mutates `graph_def`, changing the names and config_proto's
// of the Function nodes.
void CanonicalizeNodes(GraphDef* orig_graph_def) {
for (NodeDef& node : *orig_graph_def->mutable_node()) {
void CanonicalizeNodes(GraphDef* graph_def) {
for (NodeDef& node : *graph_def->mutable_node()) {
// Check if this is a function call.
if (grappler::IsPartitionedCall(node) ||
grappler::IsStatefulPartitionedCall(node)) {
// TODO(b/240174577): Strip UID from the end of function names.
// Regularize "f" attribute, the function name for PartitionedCall
// and StatefulPartitionedCall ops.
node.mutable_attr()->find("f")->second.mutable_func()->set_name(
"FINGERPRINT_PASS");
// and StatefulPartitionedCall ops, by stripping the suffix UID if it
// has one.
std::string function_name = node.attr().find("f")->second.func().name();
StatusOr<int> uid = GetSuffixUID(function_name);
if (uid.ok()) {
node.mutable_attr()->find("f")->second.mutable_func()->set_name(
std::string(
absl::StripSuffix(function_name, std::to_string(*uid))));
}
// Erase the "config_proto" attribute which contains device-specific
// information.
node.mutable_attr()->find("config_proto")->second.mutable_s()->erase();
Expand All @@ -68,17 +84,6 @@ void CanonicalizeNodes(GraphDef* orig_graph_def) {
}
}

// Returns the suffix UID of `function_name`: the integer formed by the text
// that follows the last '_' in the name (or by the whole name when it contains
// no '_').
//
// Returns an InvalidArgument error when that trailing segment cannot be parsed
// as a 32-bit integer (this includes an empty segment, e.g. a name that ends
// in '_').
//
// NOTE(review): the exact set of strings accepted here is whatever
// `strings::safe_strto32` accepts; if it tolerates signs, surrounding
// whitespace or leading zeros, the returned value may not round-trip to the
// literal suffix text -- confirm before relying on that in callers.
StatusOr<int> GetSuffixUID(absl::string_view function_name) {
  // Only the segment after the final '_' matters, so locate it directly
  // instead of splitting the whole name into a vector of string copies.
  const auto last_underscore = function_name.rfind('_');
  const absl::string_view suffix =
      last_underscore == absl::string_view::npos
          ? function_name
          : function_name.substr(last_underscore + 1);

  int uid;
  if (!strings::safe_strto32(suffix, &uid)) {
    return errors::InvalidArgument(absl::StrCat(
        "Function name: `", function_name, "` does not end in an integer."));
  }
  return uid;
}

} // namespace

uint64 ComputeHash(const GraphDef& graph_def) {
Expand Down
1 change: 1 addition & 0 deletions tensorflow/cc/saved_model/fingerprinting_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ TEST(FingerprintingTest, TestCreateFingerprint) {
CreateFingerprintDef(saved_model_pb.meta_graphs(0));

EXPECT_GT(fingerprint_def.graph_def_checksum(), 0);
EXPECT_EQ(fingerprint_def.graph_def_program_hash(), 10127142238652115842U);
EXPECT_EQ(fingerprint_def.signature_def_hash(), 5693392539583495303);
EXPECT_EQ(fingerprint_def.saved_object_graph_hash(), 3678101440349108924);
}
Expand Down
12 changes: 6 additions & 6 deletions tensorflow/python/saved_model/fingerprinting_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ def _create_saved_model(self):
self.addCleanup(shutil.rmtree, save_dir)
return save_dir

def _create_saved_model_with_function(self):
def _create_model_with_function(self):
root = autotrackable.AutoTrackable()
root.f = def_function.function(lambda x: 2. * x)
return root

def _create_saved_model_with_input_signature(self):
def _create_model_with_input_signature(self):
root = autotrackable.AutoTrackable()
root.f = def_function.function(
lambda x: 2. * x,
Expand Down Expand Up @@ -75,12 +75,12 @@ def test_basic_module(self):
# We cannot check this value due to non-determinism in serialization.
self.assertGreater(fingerprint_def.graph_def_checksum, 0)
self.assertEqual(fingerprint_def.graph_def_program_hash,
16358308617800096964)
14830488309055091319)
self.assertEqual(fingerprint_def.signature_def_hash, 1050878586713189074)

def test_model_saved_with_different_signature_options(self):
model = self._create_saved_model_with_function()
# Save the model with signatures specified in SaveOptions,.
model = self._create_model_with_function()
# Save the model with signatures specified in SaveOptions.
sig_dir = os.path.join(self.get_temp_dir(), "saved_model")
save.save(
model,
Expand All @@ -92,7 +92,7 @@ def test_model_saved_with_different_signature_options(self):
save.save(model, no_sig_dir)
# Save the model with an input signature specified.
input_sig_dir = os.path.join(self.get_temp_dir(), "saved_model3")
save.save(self._create_saved_model_with_input_signature(), input_sig_dir)
save.save(self._create_model_with_input_signature(), input_sig_dir)

fingerprint_sig = self._read_fingerprint(
file_io.join(sig_dir, constants.FINGERPRINT_FILENAME))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_module_basic(self):
# in serialization.
self.assertGreater(fingerprint_def.graph_def_checksum, 0)
self.assertEqual(fingerprint_def.graph_def_program_hash,
13188891313422428336)
10127142238652115842)
self.assertEqual(fingerprint_def.signature_def_hash, 5693392539583495303)
self.assertEqual(fingerprint_def.saved_object_graph_hash,
3678101440349108924)
Expand Down

0 comments on commit c528f87

Please sign in to comment.