Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement key batching for cassandra online store in go feature server. #165

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
204 changes: 199 additions & 5 deletions go/internal/feast/onlinestore/cassandraonlinestore.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"math"
"os"
"strings"
"sync"
Expand Down Expand Up @@ -32,6 +33,9 @@ type CassandraOnlineStore struct {
session *gocql.Session

config *registry.RepoConfig

// The number of keys to include in a single CQL query for retrieval from the database
keyBatchSize int
}

type CassandraConfig struct {
Expand All @@ -43,6 +47,7 @@ type CassandraConfig struct {
loadBalancingPolicy gocql.HostSelectionPolicy
connectionTimeoutMillis int64
requestTimeoutMillis int64
keyBatchSize int
}

func parseStringField(config map[string]any, fieldName string, defaultValue string) (string, error) {
Expand Down Expand Up @@ -155,6 +160,13 @@ func extractCassandraConfig(onlineStoreConfig map[string]any) (*CassandraConfig,
}
cassandraConfig.requestTimeoutMillis = int64(requestTimeoutMillis.(float64))

keyBatchSize, ok := onlineStoreConfig["key_batch_size"]
if !ok {
keyBatchSize = 5.0
log.Warn().Msg("key_batch_size not specified, defaulting to batches of size 5")
}
cassandraConfig.keyBatchSize = int(keyBatchSize.(float64))

return &cassandraConfig, nil
}

Expand All @@ -175,8 +187,9 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online

store.clusterConfigs.PoolConfig.HostSelectionPolicy = cassandraConfig.loadBalancingPolicy

if cassandraConfig.username != "" && cassandraConfig.password != "" {
log.Warn().Msg("username/password not defined, will not be using authentication")
if cassandraConfig.username == "" || cassandraConfig.password == "" {
log.Warn().Msg("username and/or password not defined, will not be using authentication")
} else {
store.clusterConfigs.Authenticator = gocql.PasswordAuthenticator{
Username: cassandraConfig.username,
Password: cassandraConfig.password,
Expand All @@ -202,14 +215,24 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online
return nil, fmt.Errorf("unable to connect to the ScyllaDB database")
}
store.session = createdSession

if cassandraConfig.keyBatchSize <= 0 || cassandraConfig.keyBatchSize > 100 {
return nil, fmt.Errorf("key_batch_size must be greater than zero and less than 100")
} else if cassandraConfig.keyBatchSize == 1 {
log.Info().Msg("key batching is disabled")
} else {
log.Info().Msgf("key batching is enabled with a batch size of %d", cassandraConfig.keyBatchSize)
}
store.keyBatchSize = cassandraConfig.keyBatchSize

return &store, nil
}

func (c *CassandraOnlineStore) getFqTableName(tableName string) string {
return fmt.Sprintf(`"%s"."%s_%s"`, c.clusterConfigs.Keyspace, c.project, tableName)
}

func (c *CassandraOnlineStore) getCQLStatement(tableName string, featureNames []string) string {
func (c *CassandraOnlineStore) getSingleKeyCQLStatement(tableName string, featureNames []string) string {
// this prevents fetching unnecessary features
quotedFeatureNames := make([]string, len(featureNames))
for i, featureName := range featureNames {
Expand All @@ -223,6 +246,26 @@ func (c *CassandraOnlineStore) getCQLStatement(tableName string, featureNames []
)
}

func (c *CassandraOnlineStore) getMultiKeyCQLStatement(tableName string, featureNames []string, nkeys int) string {
// this prevents fetching unnecessary features
quotedFeatureNames := make([]string, len(featureNames))
for i, featureName := range featureNames {
quotedFeatureNames[i] = fmt.Sprintf(`'%s'`, featureName)
}

keyPlaceholders := make([]string, nkeys)
for i := 0; i < nkeys; i++ {
keyPlaceholders[i] = "?"
}

return fmt.Sprintf(
`SELECT "entity_key", "feature_name", "event_ts", "value" FROM %s WHERE "entity_key" IN (%s) AND "feature_name" IN (%s)`,
tableName,
strings.Join(keyPlaceholders, ","),
strings.Join(quotedFeatureNames, ","),
)
}

func (c *CassandraOnlineStore) buildCassandraEntityKeys(entityKeys []*types.EntityKey) ([]any, map[string]int, error) {
cassandraKeys := make([]any, len(entityKeys))
cassandraKeyToEntityIndex := make(map[string]int)
Expand All @@ -237,7 +280,8 @@ func (c *CassandraOnlineStore) buildCassandraEntityKeys(entityKeys []*types.Enti
}
return cassandraKeys, cassandraKeyToEntityIndex, nil
}
func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) {

func (c *CassandraOnlineStore) UnbatchedKeysOnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) {
uniqueNames := make(map[string]int32)
for _, fvName := range featureViewNames {
uniqueNames[fvName] = 0
Expand Down Expand Up @@ -265,7 +309,7 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ

// Prepare the query
tableName := c.getFqTableName(featureViewName)
cqlStatement := c.getCQLStatement(tableName, featureNames)
cqlStatement := c.getSingleKeyCQLStatement(tableName, featureNames)

var waitGroup sync.WaitGroup
waitGroup.Add(len(serializedEntityKeys))
Expand Down Expand Up @@ -372,6 +416,156 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ
return results, nil
}

func (c *CassandraOnlineStore) BatchedKeysOnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) {
uniqueNames := make(map[string]int32)
for _, fvName := range featureViewNames {
uniqueNames[fvName] = 0
}
if len(uniqueNames) != 1 {
return nil, fmt.Errorf("rejecting OnlineRead as more than 1 feature view was tried to be read at once")
}

serializedEntityKeys, serializedEntityKeyToIndex, err := c.buildCassandraEntityKeys(entityKeys)

if err != nil {
return nil, fmt.Errorf("error when serializing entity keys for Cassandra")
}
results := make([][]FeatureData, len(entityKeys))
for i := range results {
results[i] = make([]FeatureData, len(featureNames))
}

featureNamesToIdx := make(map[string]int)
for idx, name := range featureNames {
featureNamesToIdx[name] = idx
}

featureViewName := featureViewNames[0]

// Prepare the query
tableName := c.getFqTableName(featureViewName)

// Key batching
nKeys := len(serializedEntityKeys)
batchSize := c.keyBatchSize
nBatches := int(math.Ceil(float64(nKeys) / float64(batchSize)))

batches := make([][]any, nBatches)
nAssigned := 0
for i := 0; i < nBatches; i++ {
thisBatchSize := int(math.Min(float64(batchSize), float64(nKeys-nAssigned)))
nAssigned += thisBatchSize
batches[i] = make([]any, thisBatchSize)
for j := 0; j < thisBatchSize; j++ {
batches[i][j] = serializedEntityKeys[i*batchSize+j]
}
}

var waitGroup sync.WaitGroup
waitGroup.Add(nBatches)

errorsChannel := make(chan error, nBatches)
var prevBatchLength int
var cqlStatement string
for _, batch := range batches {
go func(keyBatch []any) {
defer waitGroup.Done()

// this caches the previous batch query if it had the same number of keys
if len(keyBatch) != prevBatchLength {
cqlStatement = c.getMultiKeyCQLStatement(tableName, featureNames, len(keyBatch))
}

iter := c.session.Query(cqlStatement, keyBatch...).WithContext(ctx).Iter()

scanner := iter.Scanner()
var entityKey string
var featureName string
var eventTs time.Time
var valueStr []byte
var deserializedValue types.Value
// key 1: entityKey - key 2: featureName
batchFeatures := make(map[string]map[string]FeatureData)
for scanner.Next() {
err := scanner.Scan(&entityKey, &featureName, &eventTs, &valueStr)
if err != nil {
errorsChannel <- errors.New("could not read row in query for (entity key, feature name, value, event ts)")
return
}
if err := proto.Unmarshal(valueStr, &deserializedValue); err != nil {
errorsChannel <- errors.New("error converting parsed Cassandra Value to types.Value")
return
}

if deserializedValue.Val != nil {
if batchFeatures[entityKey] == nil {
batchFeatures[entityKey] = make(map[string]FeatureData)
}
batchFeatures[entityKey][featureName] = FeatureData{
Reference: serving.FeatureReferenceV2{
FeatureViewName: featureViewName,
FeatureName: featureName,
},
Timestamp: timestamppb.Timestamp{Seconds: eventTs.Unix(), Nanos: int32(eventTs.Nanosecond())},
Value: types.Value{
Val: deserializedValue.Val,
},
}
}
}

if err := scanner.Err(); err != nil {
errorsChannel <- errors.New("failed to scan features: " + err.Error())
return
}

for _, serializedEntityKey := range keyBatch {
for _, featName := range featureNames {
keyString := serializedEntityKey.(string)
featureData, ok := batchFeatures[keyString][featName]
if !ok {
featureData = FeatureData{
Reference: serving.FeatureReferenceV2{
FeatureViewName: featureViewName,
FeatureName: featName,
},
Value: types.Value{
Val: &types.Value_NullVal{
NullVal: types.Null_NULL,
},
},
}
}
results[serializedEntityKeyToIndex[keyString]][featureNamesToIdx[featName]] = featureData
}
}
}(batch)
}
// wait until all concurrent single-key queries are done
waitGroup.Wait()
close(errorsChannel)

var collectedErrors []error
for err := range errorsChannel {
if err != nil {
collectedErrors = append(collectedErrors, err)
}
}
if len(collectedErrors) > 0 {
return nil, errors.Join(collectedErrors...)
}

return results, nil
}

func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) {
if c.keyBatchSize == 1 {
return c.UnbatchedKeysOnlineRead(ctx, entityKeys, featureViewNames, featureNames)
} else {
return c.BatchedKeysOnlineRead(ctx, entityKeys, featureViewNames, featureNames)
}
}

func (c *CassandraOnlineStore) Destruct() {
c.session.Close()
}
15 changes: 13 additions & 2 deletions go/internal/feast/onlinestore/cassandraonlinestore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,28 @@ func TestGetFqTableName(t *testing.T) {
assert.Equal(t, `"scylladb"."dummy_project_dummy_fv"`, fqTableName)
}

func TestGetCQLStatement(t *testing.T) {
func TestGetSingleKeyCQLStatement(t *testing.T) {
store := CassandraOnlineStore{}
fqTableName := `"scylladb"."dummy_project_dummy_fv"`

cqlStatement := store.getCQLStatement(fqTableName, []string{"feat1", "feat2"})
cqlStatement := store.getSingleKeyCQLStatement(fqTableName, []string{"feat1", "feat2"})
assert.Equal(t,
`SELECT "entity_key", "feature_name", "event_ts", "value" FROM "scylladb"."dummy_project_dummy_fv" WHERE "entity_key" = ? AND "feature_name" IN ('feat1','feat2')`,
cqlStatement,
)
}

func TestGetMultiKeyCQLStatement(t *testing.T) {
store := CassandraOnlineStore{}
fqTableName := `"scylladb"."dummy_project_dummy_fv"`

cqlStatement := store.getMultiKeyCQLStatement(fqTableName, []string{"feat1", "feat2"}, 5)
assert.Equal(t,
`SELECT "entity_key", "feature_name", "event_ts", "value" FROM "scylladb"."dummy_project_dummy_fv" WHERE "entity_key" IN (?,?,?,?,?) AND "feature_name" IN ('feat1','feat2')`,
cqlStatement,
)
}

func TestOnlineRead_RejectsDifferentFeatureViewsInSameRead(t *testing.T) {
store := CassandraOnlineStore{}
_, err := store.OnlineRead(context.TODO(), nil, []string{"fv1", "fv2"}, []string{"feat1", "feat2"})
Expand Down
Loading