Add FeatureCross Layer #13

Merged · 10 commits · Jan 27, 2025
1 change: 1 addition & 0 deletions keras_rs/api/layers/__init__.py
@@ -4,6 +4,7 @@
since your modifications would be overwritten.
"""

from keras_rs.src.layers.modeling.feature_cross import FeatureCross
from keras_rs.src.layers.retrieval.brute_force_retrieval import (
BruteForceRetrieval,
)
Empty file.
204 changes: 204 additions & 0 deletions keras_rs/src/layers/modeling/feature_cross.py
@@ -0,0 +1,204 @@
from typing import Any, Optional, Text, Union

import keras
from keras import ops

from keras_rs.src import types
from keras_rs.src.api_export import keras_rs_export
from keras_rs.src.utils.keras_utils import clone_initializer


@keras_rs_export("keras_rs.layers.FeatureCross")
class FeatureCross(keras.layers.Layer):
"""FeatureCross layer in Deep & Cross Network (DCN).

A layer that creates explicit and bounded-degree feature interactions
efficiently. The `call` method accepts two inputs: `x0` contains the
original features; the second input `xi` is the output of the previous
`FeatureCross` layer in the stack, i.e., of the i-th `FeatureCross` layer.
For the first `FeatureCross` layer in the stack, `x0 = xi`.

The output is `x_{i+1} = x0 .* (W * x_i + bias + diag_scale * x_i) + x_i`,
where `.*` denotes element-wise multiplication. `W` can be a full-rank
matrix, or a low-rank matrix `U*V` to reduce the computational cost, and
`diag_scale` increases the diagonal of `W` to improve training stability
(especially for the low-rank case).

Args:
projection_dim: int. Dimension for down-projecting the input to reduce
computational cost. If `None` (default), the full matrix `W`
(with shape `(input_dim, input_dim)`) is used. Otherwise, a low-rank
matrix `W = U*V` will be used, where `U` is of shape
`(input_dim, projection_dim)` and `V` is of shape
`(projection_dim, input_dim)`. `projection_dim` needs to be smaller
than `input_dim//2` to improve model efficiency. In practice,
we've observed that `projection_dim = input_dim//4` consistently
preserves the accuracy of the full-rank version.
diag_scale: non-negative float. Used to increase the diagonal of the
kernel `W` by `diag_scale`, i.e., `W + diag_scale * I`, where `I` is the
identity matrix. Defaults to `0.0`.
use_bias: bool. Whether to add a bias term for this layer. Defaults to
`True`.
pre_activation: string or `keras.activations`. Activation applied to
output matrix of the layer, before multiplication with the input.
Can be used to control the scale of the layer's outputs and
improve stability. Defaults to `None`.
kernel_initializer: string or `keras.initializers` initializer.
Initializer to use for the kernel matrix. Defaults to
`"glorot_uniform"`.
bias_initializer: string or `keras.initializers` initializer.
Initializer to use for the bias vector. Defaults to `"zeros"`.
kernel_regularizer: string or `keras.regularizers` regularizer.
Regularizer to use for the kernel matrix.
bias_regularizer: string or `keras.regularizers` regularizer.
Regularizer to use for the bias vector.

Example:

```python
# after an embedding layer in a functional model
input = keras.Input(shape=(), name='indices', dtype="int64")
x0 = keras.layers.Embedding(input_dim=32, output_dim=6)(input)
x1 = FeatureCross()(x0, x0)
x2 = FeatureCross()(x0, x1)
logits = keras.layers.Dense(units=10)(x2)
model = keras.Model(input, logits)
```

References:
- [R. Wang et al.](https://arxiv.org/abs/2008.13535)
- [R. Wang et al.](https://arxiv.org/abs/1708.05123)
"""

def __init__(
self,
projection_dim: Optional[int] = None,
diag_scale: Optional[float] = 0.0,
use_bias: bool = True,
pre_activation: Optional[Union[str, keras.layers.Activation]] = None,
kernel_initializer: Union[
Text, keras.initializers.Initializer
] = "glorot_uniform",
bias_initializer: Union[Text, keras.initializers.Initializer] = "zeros",
kernel_regularizer: Union[
Text, None, keras.regularizers.Regularizer
] = None,
bias_regularizer: Union[
Text, None, keras.regularizers.Regularizer
] = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)

# Passed args.
self.projection_dim = projection_dim
self.diag_scale = diag_scale
self.use_bias = use_bias
self.pre_activation = keras.activations.get(pre_activation)
self.kernel_initializer = keras.initializers.get(kernel_initializer)
self.bias_initializer = keras.initializers.get(bias_initializer)
self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
self.bias_regularizer = keras.regularizers.get(bias_regularizer)

# Other args.
self.supports_masking = True

if self.diag_scale is not None and self.diag_scale < 0.0:
raise ValueError(
"`diag_scale` should be non-negative. Received: "
f"`diag_scale={self.diag_scale}`"
)

def build(self, input_shape: types.TensorShape) -> None:
last_dim = input_shape[-1]

if self.projection_dim is not None:
self.down_proj_dense = keras.layers.Dense(
units=self.projection_dim,
use_bias=False,
kernel_initializer=clone_initializer(self.kernel_initializer),
kernel_regularizer=self.kernel_regularizer,
dtype=self.dtype_policy,
)

self.dense = keras.layers.Dense(
units=last_dim,
activation=self.pre_activation,
use_bias=self.use_bias,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
kernel_regularizer=self.kernel_regularizer,
bias_regularizer=self.bias_regularizer,
dtype=self.dtype_policy,
)

self.built = True

def call(
self, x0: types.Tensor, x: Optional[types.Tensor] = None
) -> types.Tensor:
"""Forward pass of the cross layer.

Args:
x0: a Tensor. The input to the cross layer. N-rank tensor
with shape `(batch_size, ..., input_dim)`.
x: a Tensor. Optional. If provided, the layer will compute
crosses between x0 and x. Otherwise, the layer will
compute crosses between x0 and itself. Should have the same
shape as `x0`.

Returns:
Tensor of crosses, with the same shape as `x0`.
"""

if x is None:
x = x0

if x0.shape != x.shape:
raise ValueError(
"`x0` and `x` should have the same shape. Received: "
f"`x.shape` = {x.shape}, `x0.shape` = {x0.shape}"
)

# Project to a lower dimension.
if self.projection_dim is None:
output = x
else:
output = self.down_proj_dense(x)

output = self.dense(output)

output = ops.cast(output, self.compute_dtype)

if self.diag_scale:
output = ops.add(output, ops.multiply(self.diag_scale, x))

return ops.add(ops.multiply(x0, output), x)

def get_config(self) -> dict[str, Any]:
config: dict[str, Any] = super().get_config()

config.update(
{
"projection_dim": self.projection_dim,
"diag_scale": self.diag_scale,
"use_bias": self.use_bias,
"pre_activation": keras.activations.serialize(
self.pre_activation
),
"kernel_initializer": keras.initializers.serialize(
self.kernel_initializer
),
"bias_initializer": keras.initializers.serialize(
self.bias_initializer
),
"kernel_regularizer": keras.regularizers.serialize(
self.kernel_regularizer
),
"bias_regularizer": keras.regularizers.serialize(
self.bias_regularizer
),
}
)

return config
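For reference, a minimal NumPy sketch (not part of the diff) of the cross interaction described in the docstring above. Shapes and values are illustrative; the matrix is applied as `x @ W` to match how the layer's `Dense` kernel is used:

```python
import numpy as np

rng = np.random.default_rng(0)
input_dim, projection_dim = 4, 2

x0 = rng.normal(size=(1, input_dim))  # original features
x = rng.normal(size=(1, input_dim))   # output of the previous cross layer

# Full-rank cross: x_{i+1} = x0 .* (W * x_i + bias) + x_i.
W = rng.normal(size=(input_dim, input_dim))
b = np.zeros(input_dim)
full_rank_out = x0 * (x @ W + b) + x

# Low-rank cross: W is factored as U @ V to reduce the parameter count.
U = rng.normal(size=(input_dim, projection_dim))
V = rng.normal(size=(projection_dim, input_dim))
low_rank_out = x0 * (x @ U @ V + b) + x

print(full_rank_out.shape, low_rank_out.shape)  # both (1, 4), same as x0
```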
96 changes: 96 additions & 0 deletions keras_rs/src/layers/modeling/feature_cross_test.py
@@ -0,0 +1,96 @@
import keras
from absl.testing import parameterized
from keras import ops
from keras.layers import deserialize
from keras.layers import serialize

from keras_rs.src import testing
from keras_rs.src.layers.modeling.feature_cross import FeatureCross


class FeatureCrossTest(testing.TestCase, parameterized.TestCase):
def setUp(self):
self.x0 = ops.array([[0.1, 0.2, 0.3]], dtype="float32")
self.x = ops.array([[0.4, 0.5, 0.6]], dtype="float32")
self.exp_output = ops.array([[0.55, 0.8, 1.05]])

self.one_inp_exp_output = ops.array([[0.16, 0.32, 0.48]])

def test_full_layer(self):
layer = FeatureCross(projection_dim=None, kernel_initializer="ones")
output = layer(self.x0, self.x)

# Test output.
self.assertAllClose(self.exp_output, output)

# Test which layers have been initialised and their shapes.
# Kernel, bias terms corresponding to dense layer.
self.assertLen(layer.weights, 2, msg="Unexpected number of `weights`")
self.assertEqual(layer.weights[0].shape, (3, 3))
self.assertEqual(layer.weights[1].shape, (3,))

def test_low_rank_layer(self):
layer = FeatureCross(projection_dim=1, kernel_initializer="ones")
output = layer(self.x0, self.x)

# Test output.
self.assertAllClose(self.exp_output, output)

# Test which layers have been initialised and their shapes.
# Kernel term corresponding to down projection layer, and kernel,
# bias terms corresponding to dense layer.
self.assertLen(layer.weights, 3, msg="Unexpected number of `weights`")
self.assertEqual(layer.weights[0].shape, (3, 1))
self.assertEqual(layer.weights[1].shape, (1, 3))
self.assertEqual(layer.weights[2].shape, (3,))

def test_one_input(self):
layer = FeatureCross(projection_dim=None, kernel_initializer="ones")
output = layer(self.x0)
self.assertAllClose(self.one_inp_exp_output, output)

def test_invalid_input_shapes(self):
x0 = ops.ones((12, 5))
x = ops.ones((12, 7))

layer = FeatureCross()

with self.assertRaises(ValueError):
layer(x0, x)

def test_invalid_diag_scale(self):
with self.assertRaises(ValueError):
FeatureCross(diag_scale=-1.0)

def test_serialization(self):
sampler = FeatureCross(projection_dim=None, pre_activation="swish")
restored = deserialize(serialize(sampler))
self.assertDictEqual(sampler.get_config(), restored.get_config())

def test_diag_scale(self):
layer = FeatureCross(
projection_dim=None, diag_scale=1.0, kernel_initializer="ones"
)
output = layer(self.x0, self.x)

self.assertAllClose(ops.array([[0.59, 0.9, 1.23]]), output)

def test_pre_activation(self):
layer = FeatureCross(projection_dim=None, pre_activation=ops.zeros_like)
output = layer(self.x0, self.x)

self.assertAllClose(self.x, output)

def test_model_saving(self):
def get_model():
x0 = keras.layers.Input(shape=(3,))
x1 = FeatureCross(projection_dim=None)(x0, x0)
x2 = FeatureCross(projection_dim=None)(x0, x1)
logits = keras.layers.Dense(units=1)(x2)
model = keras.Model(x0, logits)
return model

self.run_model_saving_test(
model=get_model(),
input_data=self.x0,
)
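The expected values in these tests can be reproduced by hand: with `kernel_initializer="ones"` and the default zero bias, the dense layer simply sums the input features. A small NumPy check of that arithmetic (not part of the diff):

```python
import numpy as np

x0 = np.array([[0.1, 0.2, 0.3]])
x = np.array([[0.4, 0.5, 0.6]])

# With an all-ones kernel and zero bias, Dense(3) sums the features of x.
dense_out = np.full_like(x, x.sum())  # [[1.5, 1.5, 1.5]]

# test_full_layer / test_low_rank_layer expectation:
print(x0 * dense_out + x)             # [[0.55, 0.8, 1.05]]

# test_one_input: x defaults to x0, so the dense output is x0.sum() = 0.6.
print(x0 * x0.sum() + x0)             # [[0.16, 0.32, 0.48]]

# test_diag_scale: diag_scale=1.0 adds x to the dense output first.
print(x0 * (dense_out + x) + x)       # [[0.59, 0.9, 1.23]]
```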
22 changes: 22 additions & 0 deletions keras_rs/src/testing/test_case.py
@@ -1,4 +1,7 @@
import os
import tempfile
import unittest
from typing import Any

import keras
import numpy as np
@@ -54,3 +57,22 @@ def assertAllEqual(
if not isinstance(desired, np.ndarray):
desired = keras.ops.convert_to_numpy(desired)
np.testing.assert_array_equal(actual, desired, err_msg=msg)

def run_model_saving_test(
self,
model: Any,
input_data: Any,
atol: float = 1e-6,
rtol: float = 1e-6,
) -> None:
"""Save and load a model from disk and assert output is unchanged."""
model_output = model(input_data)

with tempfile.TemporaryDirectory() as temp_dir:
path = os.path.join(temp_dir, "model.keras")
model.save(path, save_format="keras_v3")
restored_model = keras.models.load_model(path)

# Check that output matches.
restored_output = restored_model(input_data)
self.assertAllClose(model_output, restored_output, atol=atol, rtol=rtol)
Empty file added keras_rs/src/utils/__init__.py
Empty file.
18 changes: 18 additions & 0 deletions keras_rs/src/utils/keras_utils.py
@@ -0,0 +1,18 @@
from typing import Text, Union

import keras


def clone_initializer(
Collaborator:
I don't think this bug applies to Keras 3 (and if it does, we'll fix it). So remove this file and don't clone the initializers.

Collaborator (Author):
I see this being used everywhere in KerasHub though. Removing it for now.

Collaborator:
Oh really? Is this something that carried over from Keras 2 and Tensorflow? Or does it still apply to Keras 3 (and if it does, is it for every backend)?

Collaborator (Author):
@hertschuh - looks like it's true for all backends. Here's a short example for JAX: https://colab.research.google.com/drive/1oY7VfaFMztOoMgueOf2TuCz5lYRQ1GXs?resourcekey=0-wnk2cldy6PkkC2qV5PVDmg&usp=sharing.

This means we need clone_initializer.

initializer: Union[Text, keras.initializers.Initializer],
) -> keras.initializers.Initializer:
"""Clones an initializer to ensure a new seed.

As of TensorFlow 2.10, we need to clone user-passed initializers when
invoking them twice to avoid creating the same randomized initialization.
"""
# If we get a string or dict, just return as we cannot and should not clone.
if not isinstance(initializer, keras.initializers.Initializer):
return initializer
config = initializer.get_config()
return initializer.__class__.from_config(config)
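A minimal usage sketch (not part of the diff) of the pattern `FeatureCross.build()` follows: each sub-layer receives its own clone of the user-passed initializer so that, per the review discussion above, reusing one instance does not yield identical weight initializations. Layer sizes here are arbitrary:

```python
import keras

from keras_rs.src.utils.keras_utils import clone_initializer

user_init = keras.initializers.GlorotUniform()  # no explicit seed

# Hand each sub-layer its own clone of the initializer instead of the shared
# instance, mirroring how `FeatureCross.build()` wires up its Dense layers.
down_proj = keras.layers.Dense(
    2, use_bias=False, kernel_initializer=clone_initializer(user_init)
)
dense = keras.layers.Dense(4, kernel_initializer=clone_initializer(user_init))
```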