
Differentiation functions with Box<dyn Trait> args fails #193

Open
motabbara opened this issue Jan 18, 2025 · 3 comments

@motabbara

Please see https://fwd.gymni.ch/eTJnUQ

It fails with "Attempting to call an indirect active function whose runtime value is inactive".

```rust
#![feature(autodiff)]
use std::autodiff::autodiff;
use std::fmt;

#[derive(Debug)]
struct Foo {
    pub test: f64
}

pub trait Cool: fmt::Debug {
    fn gen(&self) -> f64;
}

impl Cool for Foo {
    fn gen(&self) -> f64 {
        self.test * self.test
    }
}

// Concrete `&Foo` argument.
#[autodiff(dsquare, Reverse, Duplicated, Duplicated)]
pub fn square(num: &Foo, result: &mut f64) {
    *result = num.gen()
}

// `&Box<dyn Cool>` argument: this is the case that fails.
#[autodiff(dsquare2, Reverse, Duplicated, Duplicated)]
pub fn square2(num: &Box<dyn Cool>, result: &mut f64) {
    *result = num.gen()
}
```
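The two functions differ in how `gen` is called. With `&Foo`, `Foo::gen` is a direct call; with `&Box<dyn Cool>`, the address of `gen` is loaded from the trait object's vtable and called indirectly, which is presumably the "indirect ... function" the error message refers to (an interpretation, not a confirmed diagnosis). The stable-Rust sketch below involves no autodiff at all: it only shows that `Box<dyn Cool>` is a fat pointer (data pointer plus vtable pointer) and that the method is selected at runtime. `Bar` is a hypothetical second implementor added for illustration, and the method is spelled `r#gen` because `gen` is a reserved keyword in the 2024 edition.

```rust
use std::fmt;
use std::mem::size_of;

#[derive(Debug)]
struct Foo {
    pub test: f64,
}

pub trait Cool: fmt::Debug {
    // Raw identifier: same name as in the issue, but also valid on edition 2024.
    fn r#gen(&self) -> f64;
}

impl Cool for Foo {
    fn r#gen(&self) -> f64 {
        self.test * self.test
    }
}

// Hypothetical second implementor, so the call site in `square2`
// cannot be resolved to a single function at compile time.
#[derive(Debug)]
struct Bar;

impl Cool for Bar {
    fn r#gen(&self) -> f64 {
        1.0
    }
}

// Same body as `square2` in the issue: `num.r#gen()` is dispatched via the vtable.
pub fn square2(num: &Box<dyn Cool>, result: &mut f64) {
    *result = num.r#gen()
}

fn main() {
    // `Box<dyn Cool>` is a fat pointer: data pointer + vtable pointer.
    assert_eq!(size_of::<Box<dyn Cool>>(), 2 * size_of::<usize>());
    // `Box<Foo>` is a thin pointer.
    assert_eq!(size_of::<Box<Foo>>(), size_of::<usize>());

    let items: Vec<Box<dyn Cool>> = vec![Box::new(Foo { test: 3.0 }), Box::new(Bar)];
    let mut result = 0.0;
    for item in &items {
        // Which `r#gen` runs is decided at runtime from the vtable.
        square2(item, &mut result);
        println!("{item:?} -> {result}");
    }
}
```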

Incidentally, generic functions fail to differentiate even without the `Box`, e.g.:

```rust
#[autodiff(dsquare3, Reverse, Duplicated, Duplicated)]
pub fn square3<U: Cool>(num: &U, result: &mut f64) {
    *result = num.gen()
}
```
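Until generics are supported, one possible workaround (untested, and an assumption rather than something confirmed in this thread) is to keep the generic function and put `#[autodiff]` on a non-generic wrapper that instantiates it for a concrete type, so no type parameter reaches the macro. The names `square3_foo` and `dsquare3_foo` are hypothetical. So that the sketch builds on stable Rust, the attribute is shown only as a comment and only the primal is exercised.

```rust
use std::fmt;

#[derive(Debug)]
struct Foo {
    pub test: f64,
}

pub trait Cool: fmt::Debug {
    // Raw identifier: same name as in the issue, but also valid on edition 2024.
    fn r#gen(&self) -> f64;
}

impl Cool for Foo {
    fn r#gen(&self) -> f64 {
        self.test * self.test
    }
}

// The generic primal from the issue, unchanged apart from the raw identifier.
pub fn square3<U: Cool>(num: &U, result: &mut f64) {
    *result = num.r#gen()
}

// Hypothetical non-generic wrapper. On a nightly toolchain with autodiff,
// the attribute would go here instead of on `square3`:
// #[autodiff(dsquare3_foo, Reverse, Duplicated, Duplicated)]
pub fn square3_foo(num: &Foo, result: &mut f64) {
    // Instantiates `square3::<Foo>`, so the wrapper's signature has no generics.
    square3(num, result)
}

fn main() {
    let f = Foo { test: 3.0 };
    let mut c = 0.0;
    square3_foo(&f, &mut c);
    println!("square3_foo -> {c}");
}
```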
@motabbara
Author

@ZuseZ4, any recommendations on where in the codebase to look into calling trait methods through a `Box`? I'm happy to try something myself with some pointers.

@ZuseZ4
Member

ZuseZ4 commented Jan 20, 2025

I'm currently traveling, but I'll be back at my laptop on the 23rd; then I can take a closer look at the runtime inactivity. In the meantime, if you have a local build, can you run `cargo expand` and post the autodiff macro expansions? Otherwise, there may be flags to get that output from the explorer.

Support for generics should be easy to add; an earlier implementation already supported them. You need to adjust the frontend so that it no longer errors on generics, and adjust the generated autodiff function body to call the generic primal function. I'll look up the two locations in my frontend PR that you'd need to modify for that.
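Modeled on the `dsquare` expansion posted later in this thread, a generic version would presumably thread the type parameter through to both the shadow argument and the call of the primal. The sketch below is a hypothetical illustration of that shape only, not the frontend's actual output: the `#[rustc_autodiff]` attributes and the `asm!` placeholder are omitted so it builds on stable Rust, and the body does not compute any derivative (in the real pipeline the placeholder body is replaced by the generated one).

```rust
use std::fmt;
use std::hint::black_box;

#[derive(Debug)]
struct Foo {
    pub test: f64,
}

pub trait Cool: fmt::Debug {
    // Raw identifier: same name as in the issue, but also valid on edition 2024.
    fn r#gen(&self) -> f64;
}

impl Cool for Foo {
    fn r#gen(&self) -> f64 {
        self.test * self.test
    }
}

#[inline(never)]
pub fn square3<U: Cool>(num: &U, result: &mut f64) {
    *result = num.r#gen();
}

// Hypothetical generic counterpart of the expanded `dsquare`: the shadow
// `dnum` has the same type `U` as the primal, and the placeholder body calls
// the generic primal `square3::<U>`. It only runs the primal.
#[inline(never)]
pub fn dsquare3<U: Cool>(num: &U, dnum: &mut U, result: &mut f64, dresult: &mut f64) {
    black_box(square3::<U>(num, result));
    black_box((dnum, dresult));
}

fn main() {
    let f = Foo { test: 2.0 };
    let mut d_f = Foo { test: 0.0 };
    let mut c = 0.0;
    let mut d_c = 1.0;
    dsquare3(&f, &mut d_f, &mut c, &mut d_c);
    // Only the primal ran: `c` is set, the shadows are untouched.
    println!("c = {c}, d_f = {d_f:?}, d_c = {d_c}");
}
```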

@motabbara
Author

Here is the expansion:

```rust
#![feature(prelude_import)]
#![feature(autodiff)]
#[prelude_import]
use std::prelude::rust_2021::*;
#[macro_use]
extern crate std;
use std::autodiff::autodiff;
use std::fmt;
struct Foo {
    pub test: f64,
}
#[automatically_derived]
impl ::core::fmt::Debug for Foo {
    #[inline]
    fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
        ::core::fmt::Formatter::debug_struct_field1_finish(f, "Foo", "test", &&self.test)
    }
}
pub trait Cool: fmt::Debug {
    fn gen(&self) -> f64;
}
impl Cool for Foo {
    fn gen(&self) -> f64 {
        self.test * self.test
    }
}
#[rustc_autodiff]
#[inline(never)]
pub fn square(num: &Foo, result: &mut f64) {
    *result = num.gen();
}
#[rustc_autodiff(Reverse, Duplicated, Duplicated, None)]
#[inline(never)]
pub fn dsquare(num: &Foo, dnum: &mut Foo, result: &mut f64, dresult: &mut f64) {
    unsafe {
        asm!("NOP", options(pure, nomem));
    };
    ::core::hint::black_box(square(num, result));
    ::core::hint::black_box((dnum, dresult));
}
#[rustc_autodiff]
#[inline(never)]
pub fn square2(num: &Box<dyn Cool>, result: &mut f64) {
    *result = num.gen();
}
#[rustc_autodiff(Reverse, Duplicated, Duplicated, None)]
#[inline(never)]
pub fn dsquare2(
    num: &Box<dyn Cool>,
    dnum: &mut Box<dyn Cool>,
    result: &mut f64,
    dresult: &mut f64,
) {
    unsafe {
        asm!("NOP", options(pure, nomem));
    };
    ::core::hint::black_box(square2(num, result));
    ::core::hint::black_box((dnum, dresult));
}
fn main() {
    for i in 0..5 {
        let mut d_foo = Foo { test: 0.0 };
        let f = Foo { test: i as f64 };
        let mut c = 0.0;
        let mut d_c = 1.0;
        let r = dsquare(&f, &mut d_foo, &mut c, &mut d_c);
        {
            ::std::io::_print(format_args!("d_foo {0:?}\n", d_foo));
        };
        let mut d_foo: Box<dyn Cool> = Box::new(Foo { test: 0.0 });
        let f: Box<dyn Cool> = Box::new(Foo { test: i as f64 });
        let r = dsquare2(&f, &mut d_foo, &mut c, &mut d_c);
        {
            ::std::io::_print(format_args!("d_foo {0:?}\n", d_foo));
        };
    }
}
```
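For checking the `d_foo` values that `main` prints: `gen` computes `test * test`, so with the seed `d_c = 1.0` the reverse-mode derivative accumulated into `d_foo.test` should be `2 * test`, i.e. `2 * i`, given that `d_foo.test` starts at `0.0` as it does here. The stable-Rust sketch below involves no autodiff; it compares that hand-derived value against a central finite difference. The helper names `square_value`, `expected_dtest` and `finite_diff` are hypothetical.

```rust
// Primal from the issue: `Foo::gen` returns `test * test`.
fn square_value(test: f64) -> f64 {
    test * test
}

// Hand-derived reverse-mode result: seed `dresult` times d(test^2)/d(test).
fn expected_dtest(test: f64, dresult: f64) -> f64 {
    dresult * 2.0 * test
}

// Central finite difference, used only as an independent check.
fn finite_diff(f: impl Fn(f64) -> f64, x: f64, h: f64) -> f64 {
    (f(x + h) - f(x - h)) / (2.0 * h)
}

fn main() {
    // Same loop range and seed as `main` in the expansion above.
    for i in 0..5 {
        let test = i as f64;
        let analytic = expected_dtest(test, 1.0);
        let numeric = finite_diff(square_value, test, 1e-6);
        assert!((analytic - numeric).abs() < 1e-6);
        println!("i = {i}: expected d_foo.test = {analytic}");
    }
}
```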
