From d82af58f7d48e6fcc7a1e8d9875a5140e3f9bcf1 Mon Sep 17 00:00:00 2001
From: TensorFlow Lattice Authors
Date: Fri, 18 Sep 2020 10:27:37 -0700
Subject: [PATCH] Internal change

PiperOrigin-RevId: 332477482
Change-Id: I4d9653fcc54bd13ddd125d50addb07fe437d69d5
---
 docs/_book.yaml                           |  4 +-
 tensorflow_lattice/python/linear_layer.py | 40 +++++++++++-------
 tensorflow_lattice/python/linear_lib.py   | 50 ++++++++++++++++-------
 tensorflow_lattice/python/linear_test.py  | 33 ++++++++++++++-
 4 files changed, 96 insertions(+), 31 deletions(-)

diff --git a/docs/_book.yaml b/docs/_book.yaml
index 686b055..c70873c 100644
--- a/docs/_book.yaml
+++ b/docs/_book.yaml
@@ -36,6 +36,8 @@ upper_tabs:
     - name: API
       skip_translation: true
       contents:
-      - include: /lattice/api_docs/python/_toc.yaml
+      - title: All Symbols
+        path: /lattice/api_docs/python/tfl/all_symbols
+      - include: /lattice/api_docs/python/tfl/_toc.yaml
 
 - include: /_upper_tabs_right.yaml
diff --git a/tensorflow_lattice/python/linear_layer.py b/tensorflow_lattice/python/linear_layer.py
index 2ea0f19..0a2314b 100644
--- a/tensorflow_lattice/python/linear_layer.py
+++ b/tensorflow_lattice/python/linear_layer.py
@@ -56,10 +56,11 @@ class Linear(keras.layers.Layer):
   Weights can be constrained to have a fixed norm.
 
   Input shape:
-  Rank-2 tensor with shape: (batch_size, num_input_dims)
+  - if `units == 1`: tensor of shape: `(batch_size, num_input_dims)`.
+  - if `units > 1`: tensor of shape: `(batch_size, units, num_input_dims)`
 
   Output shape:
-  Rank-2 tensor with shape: (batch_size, 1)
+  Rank-2 tensor with shape: (batch_size, units)
 
   Attributes:
     - All `__init__ `arguments.
@@ -83,6 +84,7 @@ class Linear(keras.layers.Layer):
 
   def __init__(self,
                num_input_dims,
+               units=1,
                monotonicities=None,
                monotonic_dominances=None,
                range_dominances=None,
@@ -99,6 +101,7 @@ def __init__(self,
 
     Args:
       num_input_dims: Number of input dimensions.
+      units: Output dimension of the layer.
       monotonicities: None or list or tuple of length 'num_input_dims' of
         {'decreasing', 'none', 'increasing', -1, 0, 1} which specifies if the
         model output should be monotonic in corresponding feature, using
@@ -141,6 +144,7 @@ def __init__(self,
     super(Linear, self).__init__(**kwargs)
 
     self.num_input_dims = num_input_dims
+    self.units = units
 
     if isinstance(monotonicities, list) or isinstance(monotonicities, tuple):
       self.monotonicities = list(monotonicities)
@@ -177,23 +181,27 @@ def __init__(self,
       for reg in bias_regularizer:
         self.bias_regularizer.append(keras.regularizers.get(reg))
 
+    if units == 1:
+      input_shape = (None, num_input_dims)
+    else:
+      input_shape = (None, units, num_input_dims)
     self.input_spec = keras.layers.InputSpec(
-        dtype=self.dtype, shape=(None, num_input_dims))
+        dtype=self.dtype, shape=input_shape)
 
   def build(self, input_shape):
     """Standard Keras build() method.
 
     Args:
-      input_shape: Must be: (batch_size, num_input_dims)
+      input_shape: Must be: (batch_size, num_input_dims) if units == 1, or
+        (batch_size, units, num_input_dims) if units > 1.
 
     Raises:
-      ValueError: If shape is not (batch_size, num_input_dims).
+      ValueError: If shape is invalid.
     """
-    if len(input_shape) != 2 or input_shape[1] != self.num_input_dims:
-      raise ValueError("'input_shape' must be of rank two and number of "
-                       "elements of second dimension must be equal to "
-                       "'num_input_dims'. 'input_shape': " + str(input_shape) +
-                       "'num_input_dims': " + str(self.num_input_dims))
+    linear_lib.verify_hyperparameters(
+        num_input_dims=self.num_input_dims,
+        units=self.units,
+        input_shape=input_shape)
 
     if (any(self.monotonicities) or self.monotonic_dominances or
         self.range_dominances or self.normalization_order):
@@ -219,7 +227,7 @@ def build(self, input_shape):
     self.kernel = self.add_weight(
         LINEAR_LAYER_KERNEL_NAME,
         # 1 column matrix rather than verctor for matrix multiplication.
-        shape=[self.num_input_dims, 1],
+        shape=[self.num_input_dims, self.units],
         initializer=self.kernel_initializer,
         regularizer=kernel_reg,
         constraint=constraints,
@@ -234,7 +242,7 @@ def build(self, input_shape):
       bias_reg = lambda x: tf.add_n([r(x) for r in self.bias_regularizer])
     self.bias = self.add_weight(
         LINEAR_LAYER_BIAS_NAME,
-        shape=[],
+        shape=[] if self.units == 1 else [self.units],
         initializer=self.bias_initializer,
         regularizer=bias_reg,
         constraint=None,
@@ -263,7 +271,10 @@ def call(self, inputs):
           clip_value_min=self.clip_value_min,
           clip_value_max=self.clip_value_max)
 
-    result = tf.matmul(inputs, self.kernel)
+    if self.units == 1:
+      result = tf.matmul(inputs, self.kernel)
+    else:
+      result = tf.reduce_sum(inputs * tf.transpose(self.kernel), axis=-1)
     if self.use_bias:
       result += self.bias
     return result
@@ -271,12 +282,13 @@ def call(self, inputs):
   def compute_output_shape(self, input_shape):
     """Standard Keras compute_output_shape() method."""
     del input_shape
-    return [None, 1]
+    return [None, self.units]
 
   def get_config(self):
     """Standard Keras get_config() method."""
     config = {
         "num_input_dims": self.num_input_dims,
+        "units": self.units,
         "monotonicities": self.monotonicities,
         "use_bias": self.use_bias,
         "normalization_order": self.normalization_order,
diff --git a/tensorflow_lattice/python/linear_lib.py b/tensorflow_lattice/python/linear_lib.py
index 8e0f0a7..32462a7 100644
--- a/tensorflow_lattice/python/linear_lib.py
+++ b/tensorflow_lattice/python/linear_lib.py
@@ -36,7 +36,7 @@ def project(weights,
 
   Args:
     weights: Tensor which represents weights of TFL linear layer. Must have
-      shape [len(monotonicities), 1].
+      shape [len(monotonicities), units].
     monotonicities: List or tuple of same length as number of elements in
       'weights' of {-1, 0, 1} which represent monotonicity constraints per
       dimension. -1 stands for decreasing, 0 for no constraints, 1 for
@@ -55,7 +55,7 @@ def project(weights,
       Norm will be computed by: `tf.norm(tensor, ord=normalization_order)`.
 
   Raises:
-    ValueError: If shape of weights is not `(len(monotonicities), 1)`.
+    ValueError: If shape of weights is not `(len(monotonicities), units)`.
 
   Returns:
     'weights' with monotonicity constraints and normalization applied to it.
@@ -72,7 +72,7 @@ def project(weights,
     inverted_increasing_mask = tf.constant(
         value=[0.0 if m == 1 else 1.0 for m in monotonicities],
         dtype=weights.dtype,
-        shape=weights.shape)
+        shape=(weights.shape[0], 1))
     # Multiplying by this mask will keep non monotonic dims same and will
     # set monotonic dims to 0.0. Later by taking maximum with this product
     # we'll essentially take maximumum of monotonic dims with 0.0.
@@ -82,7 +82,7 @@ def project(weights,
     inverted_decreasing_mask = tf.constant(
         value=[0.0 if m == -1 else 1.0 for m in monotonicities],
         dtype=weights.dtype,
-        shape=weights.shape)
+        shape=(weights.shape[0], 1))
     weights = tf.minimum(weights, weights * inverted_decreasing_mask)
 
   if monotonic_dominances:
@@ -96,18 +96,17 @@ def project(weights,
     for dim, (lower, upper) in enumerate(zip(input_min, input_max)):
       if lower is not None and upper is not None:
         scalings[dim] *= upper - lower
-    scalings = tf.constant(scalings, dtype=weights.dtype, shape=weights.shape)
+    scalings = tf.constant(
+        scalings, dtype=weights.dtype, shape=(weights.shape[0], 1))
     weights *= scalings
     weights = internal_utils.approximately_project_categorical_partial_monotonicities(
         weights, range_dominances)
     weights /= scalings
 
   if normalization_order:
-    norm = tf.norm(weights, ord=normalization_order)
-    weights = tf.cond(
-        norm < _NORMALIZATION_EPS,
-        true_fn=lambda: weights,
-        false_fn=lambda: weights / norm)
+    norm = tf.norm(weights, axis=0, ord=normalization_order)
+    norm = tf.where(norm < _NORMALIZATION_EPS, 1.0, norm)
+    weights = weights / norm
 
   return weights
 
@@ -151,7 +150,7 @@ def assert_constraints(weights,
     # Create constant specifying shape explicitly because otherwise due to
     # weights shape ending with dimesion of size 1 broadcasting will hurt us.
     monotonicities_constant = tf.constant(
-        monotonicities, shape=weights.shape, dtype=weights.dtype)
+        monotonicities, shape=(weights.shape[0], 1), dtype=weights.dtype)
     diff = tf.reduce_min(weights * monotonicities_constant)
     asserts.append(
         tf.Assert(
@@ -193,7 +192,7 @@ def assert_constraints(weights,
             summarize=weights.shape[0]))
 
   if normalization_order:
-    norm = tf.norm(weights, ord=normalization_order)
+    norm = tf.norm(weights, axis=0, ord=normalization_order)
     asserts.append(
         # Norm can be either 0.0 or 1.0, because if all weights are close to 0.0
         # we can't scale them to get norm 1.0.
@@ -210,6 +209,8 @@ def assert_constraints(weights,
 
 
 def verify_hyperparameters(num_input_dims=None,
+                           units=None,
+                           input_shape=None,
                            monotonicities=None,
                            monotonic_dominances=None,
                            range_dominances=None,
@@ -230,6 +231,8 @@ def verify_hyperparameters(num_input_dims=None,
 
   Args:
     num_input_dims: None or number of input dimensions.
+    units: Units hyperparameter of Linear layer.
+    input_shape: Shape of layer input.
    monotonicities: List or tuple of same length as number of elements in
      `weights` of {-1, 0, 1} which represent monotonicity constraints per
      dimension. -1 stands for decreasing, 0 for no constraints, 1 for
@@ -263,9 +266,9 @@ def verify_hyperparameters(num_input_dims=None,
                        (monotonicities, len(monotonicities), num_input_dims))
 
   if weights_shape is not None:
-    if len(weights_shape) != 2 or weights_shape[1] != 1:
-      raise ValueError("Expect weights to be a row vector. Weights shape: %s" %
-                       (weights_shape,))
+    if len(weights_shape) != 2:
+      raise ValueError("Expect weights to be a rank 2 tensor. Weights shape: "
+                       "%s" % (weights_shape,))
     if monotonicities is not None and weights_shape[0] != len(monotonicities):
       raise ValueError("Number of elements in 'monotonicities' does not "
                        "correspond to number of weights. Weights shape: %s, "
@@ -281,6 +284,23 @@ def verify_hyperparameters(num_input_dims=None,
                        "to number of weights. Weights shape: %s, input_max: %s"
                        % (weights_shape, input_max))
 
+  if input_shape is not None:
+    assert units is not None and num_input_dims is not None
+    if (units > 1 and
+        (len(input_shape) != 3 or input_shape[1] != units or
+         input_shape[2] != num_input_dims)):
+      raise ValueError("'input_shape' must be of rank three and number of "
+                       "elements of second and third dimensions must be "
+                       "equal to 'units' and 'num_input_dims' respectively. "
+                       "'input_shape': " + str(input_shape) + "'units': " +
+                       str(units) + "'num_input_dims': " + str(num_input_dims))
+    elif (units == 1 and
+          (len(input_shape) != 2 or input_shape[1] != num_input_dims)):
+      raise ValueError("'input_shape' must be of rank two and number of "
+                       "elements of second dimension must be equal to "
+                       "'num_input_dims'. 'input_shape': " + str(input_shape) +
+                       "'num_input_dims': " + str(num_input_dims))
+
   for dim, (lower, upper) in enumerate(zip(input_min or [], input_max or [])):
     if lower is not None and upper is not None and lower > upper:
       raise ValueError("Cannot have 'input_min' greater than 'input_max'."
diff --git a/tensorflow_lattice/python/linear_test.py b/tensorflow_lattice/python/linear_test.py
index 784d468..cbfa250 100644
--- a/tensorflow_lattice/python/linear_test.py
+++ b/tensorflow_lattice/python/linear_test.py
@@ -100,6 +100,8 @@ def _SetDefaults(self, config):
     config.setdefault("kernel_regularizer", None)
     config.setdefault("bias_regularizer", None)
     config.setdefault("allowed_constraints_violation", 1e-6)
+    config.setdefault("units", 1)
+    config.setdefault("unit_index", 0)
 
     return config
 
   def _GetTrainingInputsAndLabels(self, config):
@@ -150,10 +152,24 @@ def _TrainModel(self, config, plot_path=None):
     training_inputs, training_labels, raw_training_inputs = (
         self._GetTrainingInputsAndLabels(config))
 
+    units = config["units"]
+    num_input_dims = config["num_input_dims"]
+    if units > 1:
+      # In order to test multi 'units' linear, replicate inputs 'units' times
+      # and later use just one out of 'units' outputs in order to ensure that
+      # multi 'units' linear trains exactly similar to single 'units' one.
+      training_inputs = [
+          np.tile(np.expand_dims(x, axis=0), reps=[units, 1])
+          for x in training_inputs
+      ]
+      input_shape = (units, num_input_dims)
+    else:
+      input_shape = (num_input_dims,)
     linear_layer = linl.Linear(
-        input_shape=[config["num_input_dims"]],
+        input_shape=input_shape,
         num_input_dims=config["num_input_dims"],
+        units=units,
         monotonicities=config["monotonicities"],
         monotonic_dominances=config["monotonic_dominances"],
         range_dominances=config["range_dominances"],
@@ -170,6 +186,11 @@ def _TrainModel(self, config, plot_path=None):
         dtype=tf.float32)
     model = keras.models.Sequential()
     model.add(linear_layer)
+    # When we use multi-unit linear, we only extract a single unit for testing.
+ if units > 1: + unit_index = config["unit_index"] + model.add( + keras.layers.Lambda(lambda x: x[:, unit_index:unit_index + 1])) optimizer = config["optimizer"](learning_rate=config["learning_rate"]) model.compile(loss=keras.losses.mean_squared_error, optimizer=optimizer) @@ -429,6 +450,16 @@ def testTwoDMonotonicity(self, expected_loss, monotonicities): self.assertAlmostEqual(loss, expected_loss, delta=_LOSS_EPS) self.assertAlmostEqual(loss, self._NegateAndTrain(config), delta=_SMALL_EPS) + multioutput_config = dict(config) + units = 3 + multioutput_config["units"] = units + for unit_index in range(units): + multioutput_config["unit_index"] = unit_index + loss = self._TrainModel(multioutput_config) + self.assertAlmostEqual(loss, expected_loss, delta=_LOSS_EPS) + self.assertAlmostEqual( + loss, self._NegateAndTrain(multioutput_config), delta=_SMALL_EPS) + @parameterized.parameters( (1, [0.2, 0.3], 0, 0.250532), # Testing sum of weights < 1.0. (1, [0.2, 0.3], 1, 0.250532), # Monotonicity does not matter here.
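A minimal usage sketch of the multi-unit layer described in the docstring above, assuming the layer is exposed as `tfl.layers.Linear` and using illustrative sizes (4 input dimensions, 3 units): with `units > 1` the input is rank-3 and the output has one column per unit.

import numpy as np
import tensorflow_lattice as tfl

# Illustrative example: 3 parallel linear units over 4 input dimensions.
linear = tfl.layers.Linear(
    num_input_dims=4,
    units=3,
    monotonicities=["increasing", "none", "increasing", "none"],
    use_bias=True)

# With units > 1 the layer expects rank-3 inputs:
# (batch_size, units, num_input_dims).
x = np.random.uniform(size=(8, 3, 4)).astype(np.float32)
y = linear(x)
print(y.shape)  # Expected: (8, 3), one output per unit.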
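The `units > 1` branch of `call()` relies on broadcasting: the kernel of shape `(num_input_dims, units)` is transposed, multiplied elementwise against the `(batch_size, units, num_input_dims)` input, and summed over the last axis, which is the same per-unit dot product an einsum would compute. A small self-contained check with arbitrary shapes:

import tensorflow as tf

batch_size, units, num_input_dims = 2, 3, 4
inputs = tf.random.uniform((batch_size, units, num_input_dims))
kernel = tf.random.uniform((num_input_dims, units))

# Broadcasting path: (batch, units, dims) * (units, dims) -> sum over dims.
out_broadcast = tf.reduce_sum(inputs * tf.transpose(kernel), axis=-1)

# Per-unit dot product written as an einsum; unit u uses column kernel[:, u].
out_einsum = tf.einsum("bud,du->bu", inputs, kernel)

tf.debugging.assert_near(out_broadcast, out_einsum)  # both: (batch, units)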
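With a multi-column kernel, `project()` normalizes each unit's weight vector independently: `tf.norm(..., axis=0)` returns one norm per column, and `tf.where` leaves near-zero columns unscaled instead of dividing by zero. A toy illustration; the epsilon value below is assumed for the sketch, while the library defines its own `_NORMALIZATION_EPS`:

import tensorflow as tf

normalization_eps = 1e-8  # Assumed threshold, standing in for _NORMALIZATION_EPS.
weights = tf.constant([[0.5, 0.0],
                       [1.5, 0.0]])  # (num_input_dims=2, units=2)

norm = tf.norm(weights, axis=0, ord=1)                # [2.0, 0.0], one norm per unit.
norm = tf.where(norm < normalization_eps, 1.0, norm)  # [2.0, 1.0], skip ~zero columns.
normalized = weights / norm
print(normalized.numpy())  # [[0.25, 0.0], [0.75, 0.0]]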
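The test exercises the multi-unit path by tiling each single-unit example across the units axis and then slicing one unit's output back out with a `Lambda` layer, so every unit should train to the same solution as the single-unit layer. A small sketch of those two pieces, with illustrative sizes and unit index:

import numpy as np
from tensorflow import keras

units, num_input_dims, unit_index = 3, 2, 1

# One example of shape (num_input_dims,) replicated for every unit:
# (num_input_dims,) -> (1, num_input_dims) -> (units, num_input_dims).
example = np.array([0.5, -1.0], dtype=np.float32)
example_multi = np.tile(np.expand_dims(example, axis=0), reps=[units, 1])
print(example_multi.shape)  # (3, 2)

# The layer's output has shape (batch_size, units); keep a single unit so the
# loss is comparable to the single-unit model.
extract_unit = keras.layers.Lambda(lambda x: x[:, unit_index:unit_index + 1])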