Add FeatureCross Layer #13

keras_rs/src/layers/modeling/feature_cross.py (new file)

@@ -0,0 +1,214 @@
from typing import Any, Optional, Text, Union

import keras
from keras import ops

from keras_rs.src import types
from keras_rs.src.api_export import keras_rs_export
from keras_rs.src.utils.keras_utils import clone_initializer


@keras_rs_export("keras_rs.layers.FeatureCross")
class FeatureCross(keras.layers.Layer):
    """FeatureCross layer in Deep & Cross Network (DCN).

    A layer that creates explicit and bounded-degree feature interactions
    efficiently. The `call` method accepts two inputs: `x0` contains the
    original features; the second input `xi` is the output of the previous
    `FeatureCross` layer in the stack, i.e., the i-th `FeatureCross` layer.
    For the first `FeatureCross` layer in the stack, `x0 = xi`.

    The output is `x_{i+1} = x0 .* (W * x_i + bias + diag_scale * x_i) + x_i`,
    where `.*` denotes element-wise multiplication. `W` could be a full-rank
    matrix, or a low-rank matrix `U*V` to reduce the computational cost, and
    `diag_scale` increases the diagonal of `W` to improve training stability
    (especially for the low-rank case).

    References:
        - [R. Wang et al.](https://arxiv.org/abs/2008.13535)
        - [R. Wang et al.](https://arxiv.org/abs/1708.05123)

    Example:

    ```python
    # after embedding layer in a functional model
    input = keras.Input(shape=(), name='indices', dtype="int64")
    x0 = keras.layers.Embedding(input_dim=32, output_dim=6)(input)
    x1 = FeatureCross()(x0, x0)
    x2 = FeatureCross()(x0, x1)
    logits = keras.layers.Dense(units=10)(x2)
    model = keras.Model(input, logits)
    ```

Review comment: Unindent so that it's lined up with `References`.

    Args:
        projection_dim: int. Dimension for down-projecting the input to
            reduce computational cost. If `None` (default), the full matrix
            `W` (with shape `(input_dim, input_dim)`) is used. Otherwise, a
            low-rank matrix `W = U*V` will be used, where `U` is of shape
            `(input_dim, projection_dim)` and `V` is of shape
            `(projection_dim, input_dim)`. `projection_dim` needs to be
            smaller than `input_dim // 2` to improve the model efficiency.
            In practice, we've observed that `projection_dim = d/4`
            consistently preserved the accuracy of a full-rank version.
        diag_scale: non-negative float. Used to increase the diagonal of the
            kernel `W` by `diag_scale`, i.e., `W + diag_scale * I`, where `I`
            is the identity matrix. Defaults to `None`.
        use_bias: bool. Whether to add a bias term for this layer. Defaults
            to `True`.
        pre_activation: string or `keras.activations`. Activation applied to
            the output matrix of the layer, before multiplication with the
            input. Can be used to control the scale of the layer's outputs
            and improve stability. Defaults to `None`.
        kernel_initializer: string or `keras.initializers` initializer.
            Initializer to use for the kernel matrix. Defaults to
            `"glorot_uniform"`.
        bias_initializer: string or `keras.initializers` initializer.
            Initializer to use for the bias vector. Defaults to `"zeros"`.
        kernel_regularizer: string or `keras.regularizers` regularizer.
            Regularizer to use for the kernel matrix.
        bias_regularizer: string or `keras.regularizers` regularizer.
            Regularizer to use for the bias vector.
    """

Review comment (on `projection_dim = d/4`): Should it be … ? Author reply: Yep, fixed it.

    def __init__(
        self,
        projection_dim: Optional[int] = None,
        diag_scale: Optional[float] = 0.0,
        use_bias: bool = True,
        pre_activation: Optional[Union[str, keras.layers.Activation]] = None,
        kernel_initializer: Union[
            Text, keras.initializers.Initializer
        ] = "glorot_uniform",
        bias_initializer: Union[Text, keras.initializers.Initializer] = "zeros",
        kernel_regularizer: Union[
            Text, None, keras.regularizers.Regularizer
        ] = None,
        bias_regularizer: Union[
            Text, None, keras.regularizers.Regularizer
        ] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

        # Passed args.
        self.projection_dim = projection_dim
        self.diag_scale = diag_scale
        self.use_bias = use_bias
        self.pre_activation = keras.activations.get(pre_activation)
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.bias_regularizer = keras.regularizers.get(bias_regularizer)

        # Other args.
        self.supports_masking = True

        if self.diag_scale is not None and self.diag_scale < 0.0:
            raise ValueError(
                "`diag_scale` should be non-negative. Received: "
                f"`diag_scale={self.diag_scale}`"
            )

    def build(self, input_shape: types.TensorShape) -> None:
        last_dim = input_shape[-1]

        dense_layer_args = {
            "units": last_dim,
            "activation": self.pre_activation,
            "use_bias": self.use_bias,
            "kernel_initializer": clone_initializer(self.kernel_initializer),
            "bias_initializer": clone_initializer(self.bias_initializer),
            "kernel_regularizer": self.kernel_regularizer,
            "bias_regularizer": self.bias_regularizer,
        }

Review comment: Just move this directly to line 135, no need for a local variable.

        if self.projection_dim is not None:
            self.down_proj_dense = keras.layers.Dense(
                units=self.projection_dim,
                use_bias=False,
                kernel_initializer=clone_initializer(self.kernel_initializer),
                kernel_regularizer=self.kernel_regularizer,
                dtype=self.dtype_policy,
            )

        self.dense = keras.layers.Dense(
            **dense_layer_args,
            dtype=self.dtype_policy,
        )

        self.built = True

    def call(
        self, x0: types.Tensor, x: Optional[types.Tensor] = None
    ) -> types.Tensor:
        """Forward pass of the cross layer.

        Args:
            x0: a Tensor. The input to the cross layer. N-rank tensor
                with shape `(batch_size, ..., input_dim)`.
            x: a Tensor. Optional. If provided, the layer will compute
                crosses between `x0` and `x`. Otherwise, the layer will
                compute crosses between `x0` and itself. Should have the
                same shape as `x0`.

        Returns:
            Tensor of crosses, with the same shape as `x0`.
        """

        if not self.built:
            self.build(x0.shape)

Review comment: This is already done by Keras, you can remove.

        if x is None:
            x = x0

        if x0.shape != x.shape:
            raise ValueError(
                "`x0` and `x` should have the same shape. Received: "
                f"`x.shape` = {x.shape}, `x0.shape` = {x0.shape}"
            )

        # Project to a lower dimension.
        if self.projection_dim is None:
            output = x
        else:
            output = self.down_proj_dense(x)

        output = self.dense(output)

        output = ops.cast(output, self.compute_dtype)

        if self.diag_scale:
            output = output + self.diag_scale * x

Review comment: `output = keras.ops.add(output, keras.ops.multiply(self.diag_scale, x))` — just in general, try to use `keras.ops`.

        return x0 * output + x

Review comment: `return keras.ops.add(keras.ops.multiply(x0, output), x)`

    def get_config(self) -> Any:

Review comment: `def get_config(self) -> dict[str, Any]:` Author reply: Did this initially, kept throwing an error.

        config = super().get_config()

        config.update(
            {
                "projection_dim": self.projection_dim,
                "diag_scale": self.diag_scale,
                "use_bias": self.use_bias,
                "pre_activation": keras.activations.serialize(
                    self.pre_activation
                ),
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
                "bias_initializer": keras.initializers.serialize(
                    self.bias_initializer
                ),
                "kernel_regularizer": keras.regularizers.serialize(
                    self.kernel_regularizer
                ),
                "bias_regularizer": keras.regularizers.serialize(
                    self.bias_regularizer
                ),
            }
        )

        # Typecast config to `dict`. This is not really needed,
        # but `typing` throws an error if we don't do this.
        config = dict(config)

        return config

Review comment: Hmm.. that's the issue with using type annotations when Keras doesn't have them. This works and I think it's a bit more elegant:

    def get_config(self) -> dict[str, Any]:
        config: dict[str, Any] = super().get_config()
        config.update({
            ...
        })
        return config
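
For readers checking the math, the sketch below recomputes the docstring formula by hand, using the same inputs as the unit tests in the next file and assuming an all-ones kernel with zero bias (`kernel_initializer="ones"`, default `bias_initializer="zeros"`). It is only an illustrative sketch, not code from this PR.

```python
from keras import ops

x0 = ops.array([[0.1, 0.2, 0.3]])
x = ops.array([[0.4, 0.5, 0.6]])
w = ops.ones((3, 3))  # all-ones kernel; bias assumed zero

wx = ops.matmul(x, w)                             # W * x_i -> [[1.5, 1.5, 1.5]]
out = ops.add(ops.multiply(x0, wx), x)            # x0 .* (W x) + x -> [[0.55, 0.8, 1.05]]

# With diag_scale = 1.0: x0 .* (W x + diag_scale * x) + x
wx_diag = ops.add(wx, ops.multiply(1.0, x))       # [[1.9, 2.0, 2.1]]
out_diag = ops.add(ops.multiply(x0, wx_diag), x)  # [[0.59, 0.9, 1.23]]
```

These values match `exp_output` in `test_full_layer` and the expectation in `test_diag_scale` below.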

keras_rs/src/layers/modeling/feature_cross_test.py (new file)

@@ -0,0 +1,97 @@
import keras
from absl.testing import parameterized
from keras import ops
from keras.layers import deserialize
from keras.layers import serialize

from keras_rs.src import testing
from keras_rs.src.layers.modeling.feature_cross import FeatureCross


class FeatureCrossTest(testing.TestCase, parameterized.TestCase):
    def setUp(self):
        self.x0 = ops.array([[0.1, 0.2, 0.3]], dtype="float32")
        self.x = ops.array([[0.4, 0.5, 0.6]], dtype="float32")
        self.exp_output = ops.array([[0.55, 0.8, 1.05]])

        self.one_inp_exp_output = ops.array([[0.16, 0.32, 0.48]])

    def test_full_layer(self):
        layer = FeatureCross(projection_dim=None, kernel_initializer="ones")
        output = layer(self.x0, self.x)

        # Test output.
        self.assertAllClose(self.exp_output, output)

        # Test which layers have been initialised and their shapes.
        # Kernel, bias terms corresponding to dense layer.
        self.assertLen(layer.weights, 2, msg="Unexpected number of `weights`")
        self.assertEqual(layer.weights[0].shape, (3, 3))
        self.assertEqual(layer.weights[1].shape, (3,))

    def test_low_rank_layer(self):
        layer = FeatureCross(projection_dim=1, kernel_initializer="ones")
        output = layer(self.x0, self.x)

        # Test output.
        self.assertAllClose(self.exp_output, output)

        # Test which layers have been initialised and their shapes.
        # Kernel term corresponding to down projection layer, and kernel,
        # bias terms corresponding to dense layer.
        self.assertLen(layer.weights, 3, msg="Unexpected number of `weights`")
        self.assertEqual(layer.weights[0].shape, (3, 1))
        self.assertEqual(layer.weights[1].shape, (1, 3))
        self.assertEqual(layer.weights[2].shape, (3,))

    def test_one_input(self):
        layer = FeatureCross(projection_dim=None, kernel_initializer="ones")
        output = layer(self.x0)
        self.assertAllClose(self.one_inp_exp_output, output)

    def test_invalid_input_shapes(self):
        x0 = ops.ones((12, 5))
        x = ops.ones((12, 7))

        layer = FeatureCross()

        with self.assertRaises(ValueError):
            layer(x0, x)

    def test_invalid_diag_scale(self):
        with self.assertRaises(ValueError):
            FeatureCross(diag_scale=-1.0)

    def test_serialization(self):
        sampler = FeatureCross(projection_dim=None, pre_activation="swish")
        restored = deserialize(serialize(sampler))
        self.assertDictEqual(sampler.get_config(), restored.get_config())

    def test_diag_scale(self):
        layer = FeatureCross(
            projection_dim=None, diag_scale=1.0, kernel_initializer="ones"
        )
        output = layer(self.x0, self.x)

        self.assertAllClose(ops.array([[0.59, 0.9, 1.23]]), output)

    def test_pre_activation(self):
        layer = FeatureCross(projection_dim=None, pre_activation=ops.zeros_like)
        output = layer(self.x0, self.x)

        self.assertAllClose(self.x, output)

    def test_saved_model(self):

Review comment (on the test name): call it …

        def get_model():
            x0 = keras.layers.Input(shape=(3,))
            x1 = FeatureCross(projection_dim=None)(x0, x0)
            x2 = FeatureCross(projection_dim=None)(x0, x1)
            logits = keras.layers.Dense(units=1)(x2)
            model = keras.Model(x0, logits)
            return model

        self.run_model_saving_test(
            cls=get_model,
            init_kwargs={},
            input_data=self.x0,
        )
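
As background for the `projection_dim` argument described in the layer docstring above, here is a rough parameter-count comparison between the full-rank and low-rank kernels; the sizes below are illustrative assumptions, not values taken from this PR.

```python
# Kernel parameters only (bias ignored), for an assumed input_dim and a
# projection_dim chosen as input_dim // 4 per the docstring heuristic.
input_dim = 1024
projection_dim = 256

full_rank = input_dim * input_dim           # W: (input_dim, input_dim)
low_rank = 2 * input_dim * projection_dim   # U: (input_dim, projection_dim), V: (projection_dim, input_dim)

print(full_rank, low_rank)  # 1048576 vs. 524288
```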

keras_rs.src.testing (modified)

@@ -1,12 +1,15 @@
import os
import unittest
from typing import Any, Dict

import keras
import numpy as np
import tensorflow as tf

Review comment: Oh, we don't want to do that, this should stay backend independent.

from keras_rs.src import types


-class TestCase(unittest.TestCase):
+class TestCase(tf.test.TestCase, unittest.TestCase):

Review comment: Why did you need to add … ? Follow-up: Correct. If we want to avoid using …

    """TestCase class for all Keras Recommenders tests."""

    def setUp(self) -> None:

@@ -54,3 +57,22 @@ def assertAllEqual(
        if not isinstance(desired, np.ndarray):
            desired = keras.ops.convert_to_numpy(desired)
        np.testing.assert_array_equal(actual, desired, err_msg=msg)

    def run_model_saving_test(
        self,
        cls: Any,
        init_kwargs: Dict[Any, Any],
        input_data: Any,
        atol: float = 1e-6,
        rtol: float = 1e-6,
    ) -> None:
        """Save and load a model from disk and assert output is unchanged."""
        model = cls(**init_kwargs)
        model_output = model(input_data)
        path = os.path.join(self.get_temp_dir(), "model.keras")
        model.save(path, save_format="keras_v3")
        restored_model = keras.models.load_model(path)

        # Check that output matches.
        restored_output = restored_model(input_data)
        self.assertAllClose(model_output, restored_output, atol=atol, rtol=rtol)

Review comment (on `init_kwargs: Dict[Any, Any]`): All you do is …

keras_rs/src/utils/keras_utils.py (new file)

@@ -0,0 +1,18 @@
from typing import Text, Union

import keras


def clone_initializer(
    initializer: Union[Text, keras.initializers.Initializer],
) -> keras.initializers.Initializer:
    """Clones an initializer to ensure a new seed.

    As of tensorflow 2.10, we need to clone user passed initializers when
    invoking them twice to avoid creating the same randomized initialization.
    """
    # If we get a string or dict, just return, as we cannot and should not
    # clone.
    if not isinstance(initializer, keras.initializers.Initializer):
        return initializer
    config = initializer.get_config()
    return initializer.__class__.from_config(config)

Review comment: I don't think this bug applies to Keras 3 (and if it does, we'll fix it). So remove this file and don't clone the initializers.

Author reply: I see this being used everywhere in KerasHub though. Removing it for now.

Review comment: Oh really? Is this something that carried over from Keras 2 and Tensorflow? Or does it still apply to Keras 3 (and if it does, is it for every backend)?

Author reply: @hertschuh - looks like it's true for all backends. Here's a short example for JAX: https://colab.research.google.com/drive/1oY7VfaFMztOoMgueOf2TuCz5lYRQ1GXs?resourcekey=0-wnk2cldy6PkkC2qV5PVDmg&usp=sharing. This means we need `clone_initializer`.
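
A minimal sketch of the reuse issue discussed in the thread above (illustrative only: the exact behavior may vary by Keras version and backend, and the initializer and shapes here are arbitrary assumptions):

```python
import keras
import numpy as np

# Reusing one unseeded initializer instance for two weights can yield the
# same randomized values twice, which is what `clone_initializer` guards
# against.
init = keras.initializers.GlorotUniform()
a = keras.ops.convert_to_numpy(init(shape=(2, 2)))
b = keras.ops.convert_to_numpy(init(shape=(2, 2)))
print(np.allclose(a, b))  # may be True: the same draw is repeated

# Cloning rebuilds the initializer from its config, so each weight gets its
# own draw.
clone = init.__class__.from_config(init.get_config())
c = keras.ops.convert_to_numpy(clone(shape=(2, 2)))
print(np.allclose(a, c))  # typically False
```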