Skip to content

Commit

Permalink
Take the gradient of division
Browse files Browse the repository at this point in the history
  • Loading branch information
samestep committed Feb 1, 2024
1 parent 954133c commit 69e73a1
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 45 deletions.
120 changes: 79 additions & 41 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,62 +9,91 @@ const bwd = "bwd";
class Autodiff {
mod: binaryen.Module;
grad: number[];
fwd: binaryen.ExpressionRef[];
vars: binaryen.Type[];
bwd: binaryen.ExpressionRef[];

constructor(mod: binaryen.Module, grad: number[]) {
constructor(mod: binaryen.Module, grad: number[], vars: binaryen.Type[]) {
this.mod = mod;
this.grad = grad;
this.fwd = [];
this.vars = vars;
this.bwd = [];
}

make(type: binaryen.Type): number {
const index = this.vars.length;
this.vars.push(type);
return index;
}

set(expr: binaryen.ExpressionRef): {
index: number;
expr: binaryen.ExpressionRef;
} {
const index = this.make(binaryen.getExpressionType(expr));
return { index, expr: this.mod.local.set(index, expr) };
}

get(index: binaryen.ExpressionRef): binaryen.ExpressionRef {
return this.mod.local.get(index, this.vars[index]);
}

local(expr: binaryen.ExpressionRef): number {
const info = binaryen.getExpressionInfo(expr);
if (info.id !== binaryen.LocalGetId)
throw Error("Only local.get is supported");
return (info as binaryen.LocalGetInfo).index;
}

binary(info: binaryen.BinaryInfo, grad: number): binaryen.ExpressionRef {
binary(
info: binaryen.BinaryInfo,
z: number,
dz: number,
): binaryen.ExpressionRef {
const x = this.local(info.left);
const y = this.local(info.right);
const dx = this.grad[x];
const dy = this.grad[y];
switch (info.op) {
case binaryen.SubFloat64:
const x = this.local(info.left);
const y = this.local(info.right);
const dx = this.grad[x];
const dy = this.grad[y];
case binaryen.SubFloat64: {
this.bwd.push(
this.mod.local.set(dx, this.mod.f64.add(this.get(dx), this.get(dz))),
this.mod.local.set(dy, this.mod.f64.sub(this.get(dy), this.get(dz))),
);
return this.mod.f64.sub(this.get(x), this.get(y));
}
case binaryen.DivFloat64: {
// this code appears to set `dy` first, using `dx1` before defining it,
// but `this.bwd` will eventually get reversed so it's fine
const dx1 = this.set(this.mod.f64.div(this.get(dz), this.get(y)));
this.bwd.push(
this.mod.local.set(
dx,
this.mod.f64.add(
this.mod.local.get(dx, binaryen.f64),
this.mod.local.get(grad, binaryen.f64),
),
this.mod.f64.add(this.get(dx), this.get(dx1.index)),
),
this.mod.local.set(
dy,
this.mod.f64.sub(
this.mod.local.get(dy, binaryen.f64),
this.mod.local.get(grad, binaryen.f64),
this.get(dy),
this.mod.f64.mul(this.get(dx1.index), this.get(z)),
),
),
dx1.expr,
);
return this.mod.f64.sub(
this.mod.local.get(x, binaryen.f64),
this.mod.local.get(y, binaryen.f64),
);
return this.mod.f64.div(this.get(x), this.get(y));
}
default:
throw Error("Unsupported binary operation");
}
}

expression(
info: binaryen.ExpressionInfo,
grad: number,
y: number,
dy: number,
): binaryen.ExpressionRef {
switch (info.id) {
case binaryen.BinaryId:
return this.binary(info as binaryen.BinaryInfo, grad);
return this.binary(info as binaryen.BinaryInfo, y, dy);
default:
throw Error("Unsupported expression");
}
Expand Down Expand Up @@ -104,23 +133,12 @@ export const autodiff = (mod: binaryen.Module) => {
const ad = new Autodiff(
mod,
params.map((_, i) => params.length + i),
[...bwdParams],
);
const body = ad.expression(
binaryen.getExpressionInfo(f.body),
bwdParams.length,
);
const out = ad.make(f.results);
const grad = ad.make(binaryen.createType(resultsGrad));
const body = ad.expression(binaryen.getExpressionInfo(f.body), out, grad);

const fwdOut = fwdParams.length;
ad.fwd.push(
mod.local.set(fwdOut, body),
mod.tuple.make([
...results.map((_, i) =>
mod.tuple.extract(mod.local.get(fwdOut, f.results), i),
),
...resultsGrad.map(() => mod.f64.const(0)),
mod.i32.const(0),
]),
);
const fwdResult = binaryen.createType([
...results,
...resultsGrad,
Expand All @@ -130,15 +148,35 @@ export const autodiff = (mod: binaryen.Module) => {
fwd,
binaryen.createType(fwdParams),
fwdResult,
[f.results],
mod.block(null, ad.fwd, fwdResult),
ad.vars.slice(fwdParams.length),
mod.block(
null,
[
mod.local.set(out, body),
mod.tuple.make([
...results.map((_, i) =>
mod.tuple.extract(mod.local.get(out, f.results), i),
),
...resultsGrad.map(() => mod.f64.const(0)),
mod.i32.const(0),
]),
],
fwdResult,
),
);
mod.addFunctionExport(fwd, fwd);

const bwdOut = bwdParams.length;
ad.bwd.push(
mod.local.set(
bwdOut,
out,
mod.tuple.make(
results.map((_, i) =>
mod.local.get(params.length + paramsGrad.length + i, binaryen.f64),
),
),
),
mod.local.set(
grad,
mod.tuple.make(
resultsGrad.map((_, i) =>
mod.local.get(
Expand All @@ -160,7 +198,7 @@ export const autodiff = (mod: binaryen.Module) => {
bwd,
binaryen.createType(bwdParams),
bwdResult,
[binaryen.createType(resultsGrad)],
ad.vars.slice(bwdParams.length),
mod.block(null, ad.bwd, bwdResult),
);
mod.addFunctionExport(bwd, bwd);
Expand Down
39 changes: 35 additions & 4 deletions src/test/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,41 @@ test("subtraction", async () => {
});

test("division", async () => {
const { div } = await compile<{ div: (a: number, b: number) => number }>(
await wat(await slurp("div.wat")),
);
expect(div(2, 3)).toBe(2 / 3);
let binary;
const mod = binaryen.parseText(await slurp("div.wat"));
try {
mod.setFeatures(binaryen.Features.Multivalue);
autodiff(mod);
binary = mod.emitBinary();
} finally {
mod.dispose();
}
const { fwd, bwd } = await compile<{
fwd: (
a: number,
b: number,
da: number,
db: number,
) => [number, number, number];
bwd: (
a: number,
b: number,
da: number,
db: number,
c: number,
dc: number,
t: number,
) => [number, number];
}>(binary);
const a = 5;
const b = 3;
let da = 0;
let db = 0;
let [c, dc, t] = fwd(a, b, da, db);
expect([c, dc, t]).toEqual([5 / 3, 0, 0]);
dc = 1;
[da, db] = bwd(a, b, da, db, c, dc, t);
expect([da, db]).toEqual([1 / 3, -5 / 9]);
});

test("multiple memories", async () => {
Expand Down

0 comments on commit 69e73a1

Please sign in to comment.