Skip to content

Commit

Permalink
feat: improved constant propagation (#596)
Browse files Browse the repository at this point in the history
This improves the constant propagation a little so that terms involving
multiple arithmetic constants are always reduced.
  • Loading branch information
DavePearce authored Jan 23, 2025
1 parent a39d8e8 commit fe70003
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 11 deletions.
49 changes: 39 additions & 10 deletions pkg/mir/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,26 @@ func applyConstantPropagation(e Expr, schema sc.Schema) Expr {

func applyConstantPropagationAdd(es []Expr, schema sc.Schema) Expr {
sum := fr.NewElement(0)
is_const := true
count := 0
rs := make([]Expr, len(es))
//
for i, e := range es {
rs[i] = applyConstantPropagation(e, schema)
// Check for constant
c, ok := rs[i].(*Constant)
// Try to continue sum
if ok && is_const {
if ok {
sum.Add(&sum, &c.Value)
} else {
is_const = false
// Increase count of constants
count++
}
}
// Check if constant
if is_const {
if count == len(es) {
// Propagate constant
return &Constant{sum}
} else if count > 1 {
rs = mergeConstants(sum, rs)
}
// Done
return &Add{rs}
Expand Down Expand Up @@ -85,10 +87,10 @@ func applyConstantPropagationSub(es []Expr, schema sc.Schema) Expr {

func applyConstantPropagationMul(es []Expr, schema sc.Schema) Expr {
one := fr.NewElement(1)
is_const := true
prod := one
rs := make([]Expr, len(es))
ones := 0
consts := 0
//
for i, e := range es {
rs[i] = applyConstantPropagation(e, schema)
Expand All @@ -100,23 +102,27 @@ func applyConstantPropagationMul(es []Expr, schema sc.Schema) Expr {
return &Constant{c.Value}
} else if ok && c.Value.IsOne() {
ones++
consts++
rs[i] = nil
} else if ok && is_const {
} else if ok {
// Continue building constant
prod.Mul(&prod, &c.Value)
} else {
is_const = false
//
consts++
}
}
// Check if constant
if is_const {
if consts == len(es) {
return &Constant{prod}
} else if ones > 0 {
rs = util.RemoveMatching[Expr](rs, func(item Expr) bool { return item == nil })
}
// Sanity check what's left.
if len(rs) == 1 {
return rs[0]
} else if consts-ones > 1 {
// Combine constants
rs = mergeConstants(prod, rs)
}
// Done
return &Mul{rs}
Expand Down Expand Up @@ -155,3 +161,26 @@ func applyConstantPropagationNorm(arg Expr, schema sc.Schema) Expr {
//
return &Normalise{arg}
}

// Replace all constants within a given sequence of expressions with a single
// constant (whose value has been precomputed from those constants). The new
// value replaces the first constant in the list.
func mergeConstants(constant fr.Element, es []Expr) []Expr {
j := 0
first := true
//
for i := range es {
// Check for constant
if _, ok := es[i].(*Constant); ok && first {
es[j] = &Constant{constant}
first = false
j++
} else if !ok {
// Retain non-constant expression
es[j] = es[i]
j++
}
}
// Return slice
return es[0:j]
}
1 change: 0 additions & 1 deletion pkg/schema/assignment/computation.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,6 @@ func mapIfNativeFunction(trace tr.Trace, sources []uint) []util.FrArray {
rhs := fmt.Sprintf("%v=>%s", ith_row, val.String())
panic(fmt.Sprintf("conflicting values in source map (row %d): %s vs %s", i, lhs, rhs))
} else if !ok {
fmt.Printf("Inserting source key (row %d): %v\n", i, extractIthColumns(i, source_keys))
// Item not previously in map
source_map.Insert(ith_key, ith_value)
}
Expand Down

0 comments on commit fe70003

Please sign in to comment.