-
Notifications
You must be signed in to change notification settings - Fork 74.1k
/
generate2_test.py
95 lines (80 loc) · 3.14 KB
/
generate2_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.tools.docs.generate2."""
import os
import pathlib
import shutil
import types
from unittest import mock
from packaging import version
import tensorflow as tf
import yaml
from tensorflow.python.platform import googletest
from tensorflow.tools.docs import generate2
class AutoModule(types.ModuleType):
  """A module that fabricates child modules on first attribute access.

  Reading any public (non-underscore) attribute that does not already exist
  creates a new `AutoModule` named after that attribute, stores it on this
  module, and returns it. This lets nested paths such as
  `fake_tf.keras.layers.Layer = ...` be assigned without declaring each
  intermediate module by hand.
  """

  def __getattr__(self, name):
    # Private and dunder lookups must fail normally; otherwise `hasattr` and
    # introspection would treat every probe as a real submodule.
    if name.startswith('_'):
      raise AttributeError(
          f'module {self.__name__!r} has no attribute {name!r}'
      )
    mod = AutoModule(name)
    # Store the child so repeated access returns the same module object.
    setattr(self, name, mod)
    return mod
# Make a mock tensorflow package that won't take too long to test. Only the
# handful of symbols assigned below are exposed, which keeps doc generation
# small; intermediate modules (e.g. `fake_tf.keras.layers`) are created on
# demand by `AutoModule.__getattr__`.
fake_tf = AutoModule('FakeTensorFlow')
fake_tf.Module = tf.Module  # pylint: disable=invalid-name
fake_tf.feature_column.numeric_column = tf.feature_column.numeric_column
fake_tf.keras.Model = tf.keras.Model
fake_tf.keras.preprocessing = tf.keras.preprocessing
fake_tf.keras.layers.Layer = tf.keras.layers.Layer
fake_tf.keras.optimizers.Optimizer = tf.keras.optimizers.Optimizer
fake_tf.nn.sigmoid_cross_entropy_with_logits = (
    tf.nn.sigmoid_cross_entropy_with_logits
)
fake_tf.raw_ops.Add = tf.raw_ops.Add
fake_tf.raw_ops.Print = tf.raw_ops.Print  # op with no XLA support
fake_tf.summary.audio = tf.summary.audio
# NOTE(review): `audio2` is a second name bound to the same function object;
# presumably this exercises the generator's handling of aliased symbols —
# confirm.
fake_tf.summary.audio2 = tf.summary.audio
# Mirror the real version string; the end-to-end test gates some assertions
# on it.
fake_tf.__version__ = tf.__version__
class Generate2Test(googletest.TestCase):
  """End-to-end test of `generate2.build_docs` run against `fake_tf`."""

  @mock.patch.object(generate2, 'tf', fake_tf)
  # The fake package yields only a few pages, so lower the generator's
  # file-count check. Patching (instead of assigning the module global)
  # restores the original value after the test so other tests are unaffected.
  @mock.patch.object(generate2, 'MIN_NUM_FILES_EXPECTED', 1)
  def test_end_to_end(self):
    """Builds docs and checks the raw_ops page, TOC, and redirects output."""
    output_dir = pathlib.Path(googletest.GetTempDir()) / 'output'
    # Start from an empty directory so files left by a previous run cannot
    # satisfy the assertions below.
    if output_dir.exists():
      shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True)

    generate2.build_docs(
        output_dir=output_dir,
        code_url_prefix='',
        search_hints=True,
    )

    raw_ops_page = (output_dir / 'tf/raw_ops.md').read_text()
    self.assertIn('/tf/raw_ops/Add.md', raw_ops_page)

    toc = yaml.safe_load((output_dir / 'tf/_toc.yaml').read_text())
    self.assertEqual(
        {'title': 'Overview', 'path': '/tf_overview'},
        toc['toc'][0]['section'][0],
    )

    redirects = yaml.safe_load(
        (output_dir / 'tf/_redirects.yaml').read_text()
    )
    self.assertIn({'from': '/tf_overview', 'to': '/tf'}, redirects['redirects'])

    # From TF 2.14 the raw_ops table carries per-op check/cross columns;
    # Print (an op with no XLA support) gets a cross in the last column.
    if version.parse(fake_tf.__version__) >= version.parse('2.14'):
      self.assertIn(
          '<a id=Add href="/tf/raw_ops/Add.md">Add</a> | ✔️ | ✔️ |', raw_ops_page
      )
      self.assertIn(
          '<a id=Print href="/tf/raw_ops/Print.md">Print</a> | ✔️ | ❌ |',
          raw_ops_page,
      )
if __name__ == '__main__':
  # Discover and run the test cases above when executed as a script.
  googletest.main()