Add FeatureCross Layer #13
Merged
+341 −0
Commits (10):
c23cae5  Add cross feature interaction layer (abheesht17)
b5fa030  Add unit tests (abheesht17)
6ea37b6  Add unit tests (abheesht17)
9c1cc01  Small change (abheesht17)
c2637ab  Fix doc-strings (abheesht17)
e884d48  Clean up doc-string example (abheesht17)
e00e081  Address comments (abheesht17)
f6f88c0  Restore init cloning (abheesht17)
59e6e66  Merge branch 'keras-team:main' into dcn-example (abheesht17)
ae0b0af  Add missing __init__.py file (abheesht17)
New file (empty `__init__.py`).
keras_rs/src/layers/modeling/feature_cross.py
@@ -0,0 +1,204 @@
```python
from typing import Any, Optional, Text, Union

import keras
from keras import ops

from keras_rs.src import types
from keras_rs.src.api_export import keras_rs_export
from keras_rs.src.utils.keras_utils import clone_initializer


@keras_rs_export("keras_rs.layers.FeatureCross")
class FeatureCross(keras.layers.Layer):
    """FeatureCross layer in Deep & Cross Network (DCN).

    A layer that creates explicit and bounded-degree feature interactions
    efficiently. The `call` method accepts two inputs: `x0` contains the
    original features; the second input `xi` is the output of the previous
    `FeatureCross` layer in the stack, i.e., the i-th `FeatureCross` layer.
    For the first `FeatureCross` layer in the stack, `x0 = xi`.

    The output is `x_{i+1} = x0 .* (W * x_i + bias + diag_scale * x_i) + x_i`,
    where `.*` denotes element-wise multiplication. `W` could be a full-rank
    matrix, or a low-rank matrix `U*V` to reduce the computational cost, and
    `diag_scale` increases the diagonal of `W` to improve training stability
    (especially for the low-rank case).
    Args:
        projection_dim: int. Dimension for down-projecting the input to reduce
            computational cost. If `None` (default), the full matrix, `W`
            (with shape `(input_dim, input_dim)`), is used. Otherwise, a
            low-rank matrix `W = U*V` will be used, where `U` is of shape
            `(input_dim, projection_dim)` and `V` is of shape
            `(projection_dim, input_dim)`. `projection_dim` needs to be
            smaller than `input_dim//2` to improve model efficiency. In
            practice, we've observed that `projection_dim = input_dim//4`
            consistently preserved the accuracy of a full-rank version.
        diag_scale: non-negative float. Used to increase the diagonal of the
            kernel `W` by `diag_scale`, i.e., `W + diag_scale * I`, where `I`
            is the identity matrix. Defaults to `0.0`.
        use_bias: bool. Whether to add a bias term for this layer. Defaults
            to `True`.
        pre_activation: string or `keras.activations`. Activation applied to
            the output matrix of the layer, before multiplication with the
            input. Can be used to control the scale of the layer's outputs
            and improve stability. Defaults to `None`.
        kernel_initializer: string or `keras.initializers` initializer.
            Initializer to use for the kernel matrix. Defaults to
            `"glorot_uniform"`.
        bias_initializer: string or `keras.initializers` initializer.
            Initializer to use for the bias vector. Defaults to `"zeros"`.
        kernel_regularizer: string or `keras.regularizers` regularizer.
            Regularizer to use for the kernel matrix.
        bias_regularizer: string or `keras.regularizers` regularizer.
            Regularizer to use for the bias vector.

    Example:

    ```python
    # After an embedding layer in a functional model:
    inputs = keras.Input(shape=(), name="indices", dtype="int64")
    x0 = keras.layers.Embedding(input_dim=32, output_dim=6)(inputs)
    x1 = FeatureCross()(x0, x0)
    x2 = FeatureCross()(x0, x1)
    logits = keras.layers.Dense(units=10)(x2)
    model = keras.Model(inputs, logits)
    ```

    References:
    - [R. Wang et al., "DCN V2: Improved Deep & Cross Network"](https://arxiv.org/abs/2008.13535)
    - [R. Wang et al., "Deep & Cross Network for Ad Click Predictions"](https://arxiv.org/abs/1708.05123)
    """

    def __init__(
        self,
        projection_dim: Optional[int] = None,
        diag_scale: Optional[float] = 0.0,
        use_bias: bool = True,
        pre_activation: Optional[Union[str, keras.layers.Activation]] = None,
        kernel_initializer: Union[
            Text, keras.initializers.Initializer
        ] = "glorot_uniform",
        bias_initializer: Union[Text, keras.initializers.Initializer] = "zeros",
        kernel_regularizer: Union[
            Text, None, keras.regularizers.Regularizer
        ] = None,
        bias_regularizer: Union[
            Text, None, keras.regularizers.Regularizer
        ] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

        # Passed args.
        self.projection_dim = projection_dim
        self.diag_scale = diag_scale
        self.use_bias = use_bias
        self.pre_activation = keras.activations.get(pre_activation)
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.bias_regularizer = keras.regularizers.get(bias_regularizer)

        # Other args.
        self.supports_masking = True

        if self.diag_scale is not None and self.diag_scale < 0.0:
            raise ValueError(
                "`diag_scale` should be non-negative. Received: "
                f"`diag_scale={self.diag_scale}`"
            )

    def build(self, input_shape: types.TensorShape) -> None:
        last_dim = input_shape[-1]

        if self.projection_dim is not None:
            self.down_proj_dense = keras.layers.Dense(
                units=self.projection_dim,
                use_bias=False,
                kernel_initializer=clone_initializer(self.kernel_initializer),
                kernel_regularizer=self.kernel_regularizer,
                dtype=self.dtype_policy,
            )

        self.dense = keras.layers.Dense(
            units=last_dim,
            activation=self.pre_activation,
            use_bias=self.use_bias,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            bias_initializer=clone_initializer(self.bias_initializer),
            kernel_regularizer=self.kernel_regularizer,
            bias_regularizer=self.bias_regularizer,
            dtype=self.dtype_policy,
        )

        self.built = True

    def call(
        self, x0: types.Tensor, x: Optional[types.Tensor] = None
    ) -> types.Tensor:
        """Forward pass of the cross layer.

        Args:
            x0: a Tensor. The input to the cross layer. N-rank tensor
                with shape `(batch_size, ..., input_dim)`.
            x: a Tensor. Optional. If provided, the layer will compute
                crosses between `x0` and `x`. Otherwise, the layer will
                compute crosses between `x0` and itself. Should have the
                same shape as `x0`.

        Returns:
            Tensor of crosses, with the same shape as `x0`.
        """
        if x is None:
            x = x0

        if x0.shape != x.shape:
            raise ValueError(
                "`x0` and `x` should have the same shape. Received: "
                f"`x.shape` = {x.shape}, `x0.shape` = {x0.shape}"
            )

        # Project to a lower dimension.
        if self.projection_dim is None:
            output = x
        else:
            output = self.down_proj_dense(x)

        output = self.dense(output)

        output = ops.cast(output, self.compute_dtype)

        if self.diag_scale:
            output = ops.add(output, ops.multiply(self.diag_scale, x))

        return ops.add(ops.multiply(x0, output), x)

    def get_config(self) -> dict[str, Any]:
        config: dict[str, Any] = super().get_config()

        config.update(
            {
                "projection_dim": self.projection_dim,
                "diag_scale": self.diag_scale,
                "use_bias": self.use_bias,
                "pre_activation": keras.activations.serialize(
                    self.pre_activation
                ),
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
                "bias_initializer": keras.initializers.serialize(
                    self.bias_initializer
                ),
                "kernel_regularizer": keras.regularizers.serialize(
                    self.kernel_regularizer
                ),
                "bias_regularizer": keras.regularizers.serialize(
                    self.bias_regularizer
                ),
            }
        )

        return config
```
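Not part of the diff: for reviewers who want to exercise the layer, here is a minimal usage sketch. It assumes the `keras_rs.layers.FeatureCross` import path declared by the `@keras_rs_export` decorator above; the input width and layer names are illustrative.

```python
import keras
import numpy as np

from keras_rs.layers import FeatureCross  # path per @keras_rs_export above

# An 8-wide dense feature vector; widths here are illustrative.
inputs = keras.Input(shape=(8,), name="features")
x1 = FeatureCross()(inputs, inputs)              # full-rank: W is (8, 8)
x2 = FeatureCross(projection_dim=2)(inputs, x1)  # low-rank: U (8, 2), V (2, 8)
outputs = keras.layers.Dense(units=1)(x2)
model = keras.Model(inputs, outputs)

print(model(np.random.uniform(size=(4, 8))).shape)  # (4, 1)
```

Note how `x0` (here `inputs`) is fed to every cross layer while the second argument advances through the stack, matching the `x_{i+1}` recurrence in the docstring.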
keras_rs/src/layers/modeling/feature_cross_test.py
@@ -0,0 +1,96 @@
```python
import keras
from absl.testing import parameterized
from keras import ops
from keras.layers import deserialize
from keras.layers import serialize

from keras_rs.src import testing
from keras_rs.src.layers.modeling.feature_cross import FeatureCross


class FeatureCrossTest(testing.TestCase, parameterized.TestCase):
    def setUp(self):
        self.x0 = ops.array([[0.1, 0.2, 0.3]], dtype="float32")
        self.x = ops.array([[0.4, 0.5, 0.6]], dtype="float32")
        self.exp_output = ops.array([[0.55, 0.8, 1.05]])

        self.one_inp_exp_output = ops.array([[0.16, 0.32, 0.48]])

    def test_full_layer(self):
        layer = FeatureCross(projection_dim=None, kernel_initializer="ones")
        output = layer(self.x0, self.x)

        # Test output.
        self.assertAllClose(self.exp_output, output)

        # Test which weights have been initialised and their shapes:
        # kernel and bias terms corresponding to the dense layer.
        self.assertLen(layer.weights, 2, msg="Unexpected number of `weights`")
        self.assertEqual(layer.weights[0].shape, (3, 3))
        self.assertEqual(layer.weights[1].shape, (3,))

    def test_low_rank_layer(self):
        layer = FeatureCross(projection_dim=1, kernel_initializer="ones")
        output = layer(self.x0, self.x)

        # Test output.
        self.assertAllClose(self.exp_output, output)

        # Test which weights have been initialised and their shapes:
        # the kernel of the down-projection layer, and the kernel and
        # bias terms of the dense layer.
        self.assertLen(layer.weights, 3, msg="Unexpected number of `weights`")
        self.assertEqual(layer.weights[0].shape, (3, 1))
        self.assertEqual(layer.weights[1].shape, (1, 3))
        self.assertEqual(layer.weights[2].shape, (3,))

    def test_one_input(self):
        layer = FeatureCross(projection_dim=None, kernel_initializer="ones")
        output = layer(self.x0)
        self.assertAllClose(self.one_inp_exp_output, output)

    def test_invalid_input_shapes(self):
        x0 = ops.ones((12, 5))
        x = ops.ones((12, 7))

        layer = FeatureCross()

        with self.assertRaises(ValueError):
            layer(x0, x)

    def test_invalid_diag_scale(self):
        with self.assertRaises(ValueError):
            FeatureCross(diag_scale=-1.0)

    def test_serialization(self):
        layer = FeatureCross(projection_dim=None, pre_activation="swish")
        restored = deserialize(serialize(layer))
        self.assertDictEqual(layer.get_config(), restored.get_config())

    def test_diag_scale(self):
        layer = FeatureCross(
            projection_dim=None, diag_scale=1.0, kernel_initializer="ones"
        )
        output = layer(self.x0, self.x)

        self.assertAllClose(ops.array([[0.59, 0.9, 1.23]]), output)

    def test_pre_activation(self):
        layer = FeatureCross(projection_dim=None, pre_activation=ops.zeros_like)
        output = layer(self.x0, self.x)

        self.assertAllClose(self.x, output)

    def test_model_saving(self):
        def get_model():
            x0 = keras.layers.Input(shape=(3,))
            x1 = FeatureCross(projection_dim=None)(x0, x0)
            x2 = FeatureCross(projection_dim=None)(x0, x1)
            logits = keras.layers.Dense(units=1)(x2)
            model = keras.Model(x0, logits)
            return model

        self.run_model_saving_test(
            model=get_model(),
            input_data=self.x0,
        )
```
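As a cross-check on the hard-coded expectations above (a hand calculation, not part of the diff): with `kernel_initializer="ones"` and the default zero bias, the dense projection maps every entry to `sum(x)`, so both expected outputs follow directly from the recurrence in the layer's docstring.

```python
import numpy as np

x0 = np.array([[0.1, 0.2, 0.3]])
x = np.array([[0.4, 0.5, 0.6]])

# With an all-ones kernel and zero bias, dense(x) = [sum(x)] * 3 = [1.5] * 3.
dense_out = np.full_like(x, x.sum())

# x_{i+1} = x0 .* dense(x) + x
print(x0 * dense_out + x)                    # [[0.55 0.8  1.05]]  -> exp_output
print(x0 * np.full_like(x0, x0.sum()) + x0)  # [[0.16 0.32 0.48]]  -> one-input case
```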
New file (empty `__init__.py`).
keras_rs/src/utils/keras_utils.py
@@ -0,0 +1,18 @@
```python
from typing import Text, Union

import keras


def clone_initializer(
    initializer: Union[Text, keras.initializers.Initializer],
) -> keras.initializers.Initializer:
    """Clones an initializer to ensure a new seed.

    As of TensorFlow 2.10, we need to clone user-passed initializers when
    invoking them twice to avoid creating the same randomized initialization.
    """
    # If we get a string or dict, return it as-is; we cannot (and should not)
    # clone it.
    if not isinstance(initializer, keras.initializers.Initializer):
        return initializer
    config = initializer.get_config()
    return initializer.__class__.from_config(config)
```
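For context (not in the diff): `build()` in feature_cross.py passes each user-supplied initializer through this helper before handing it to a sub-layer, so the down-projection and the main dense layer each get their own initializer instance. A short sketch of what the helper guarantees; the assertions follow from the code above rather than from any documented API:

```python
import keras

init = keras.initializers.GlorotUniform()
clone = clone_initializer(init)

assert clone is not init                        # fresh instance, fresh seed state
assert clone.get_config() == init.get_config()  # same configuration

# Strings (and other non-Initializer specs) pass through untouched;
# Keras resolves them per layer anyway.
assert clone_initializer("glorot_uniform") == "glorot_uniform"
```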
Review discussion (on keras_utils.py):

hertschuh: I don't think this bug applies to Keras 3 (and if it does, we'll fix it). So remove this file and don't clone the initializers.

abheesht17: I see this being used everywhere in KerasHub though. Removing it for now.

hertschuh: Oh really? Is this something that carried over from Keras 2 and TensorFlow? Or does it still apply to Keras 3 (and if it does, is it for every backend)?

abheesht17: @hertschuh - looks like it's true for all backends. Here's a short example for JAX: https://colab.research.google.com/drive/1oY7VfaFMztOoMgueOf2TuCz5lYRQ1GXs?resourcekey=0-wnk2cldy6PkkC2qV5PVDmg&usp=sharing. This means we need `clone_initializer`.
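The linked Colab is not reproduced here, but the failure mode under discussion can be sketched as follows. The printed `True`/`False` outcomes assume, per the thread above, that re-invoking a single initializer instance returns identical values on every backend:

```python
import keras
from keras import ops

from keras_rs.src.utils.keras_utils import clone_initializer

init = keras.initializers.GlorotUniform()

# Two draws from the same instance come out identical (the reported behavior),
# so two sub-layers sharing `init` would start from the same kernel values.
a = init(shape=(4, 4))
b = init(shape=(4, 4))
print(bool(ops.all(ops.equal(a, b))))  # True (the bug)

# Cloning before the second use restores independent initialization.
c = clone_initializer(init)(shape=(4, 4))
print(bool(ops.all(ops.equal(a, c))))  # False: the clone re-draws its values
```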