
[js/webgpu] ConvTranspose1D slower on Webgpu than Wasm #23273

Open
gianlourbano opened this issue Jan 7, 2025 · 2 comments
Labels
ep:WebGPU (ort-web webgpu provider) · platform:web (issues related to ONNX Runtime web; typically submitted using template)

Comments

@gianlourbano

Describe the issue

ConvTranspose1D with input shape [8, 4098, 435], weight shape [4098, 1, 4096], stride 1024, and padding 0 appears to be slower on WebGPU than Wasm, with timings:

EP timing (M1 MacBook Pro):
wasm: 6 s
webgpu (latest Chrome): 30 s
webgpu (Chrome Canary): 18 s

Canary is faster due to this bug.
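
For context on the workload: by the standard ConvTranspose1d output-length formula (default dilation and output_padding), L_out = (L_in − 1) · stride − 2 · padding + kernel_size = (435 − 1) · 1024 − 0 + 4096 = 448512, so each run produces an [8, 1, 448512] output tensor.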

To reproduce

A simple torch script to generate the conv and export it to ONNX:

import torch

class ConvTest(torch.nn.Module):
    def __init__(self, weight, stride, padding=0):
        super().__init__()
        self.weight = weight
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        return torch.nn.functional.conv_transpose1d(
            x, self.weight, stride=self.stride, padding=self.padding
        )

# Weight shape is [in_channels, out_channels, kernel_size] = [4098, 1, 4096].
convtest = ConvTest(weight=torch.randn(4098, 1, 4096), stride=1024)

input = torch.randn(8, 4098, 435)

torch.onnx.export(
    convtest,
    (input,),
    "convtest.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=20,
    dynamo=True,
    do_constant_folding=True,
    keep_initializers_as_inputs=True,
    # report=True,
    external_data=None,
    # verify=True
)

To test in the browser:

const session = await ort.InferenceSession.create("/convtest.onnx", {
    executionProviders: ["webgpu"],
    // logSeverityLevel: 0
});

const wgpu_profile = [];

// Collect per-kernel GPU timings reported by the webgpu profiler.
ort.env.webgpu.profiling = {
    mode: "default",
    ondata: (data) => {
        wgpu_profile.push(data);
    }
};

const input_dims = [8, 4098, 435];
const size = 8 * 4098 * 435;

const no_chunks = 1;
const chunks = [];

for (let i = 0; i < no_chunks; i++) {
    const chunk = new Float32Array(size);
    chunks.push(chunk);
}

for (let i = 0; i < no_chunks; i++) {
    console.time("onnx step " + i);
    const input = new ort.Tensor("float32", chunks[i], input_dims);
    const output = await session.run({ input });
    console.timeEnd("onnx step " + i);
}

await session.release();

// Sort kernels by duration (ascending) and print them in milliseconds.
wgpu_profile.sort((a, b) => (a.endTime - a.startTime) - (b.endTime - b.startTime));

wgpu_profile.forEach((kernel) => {
    console.log(`${kernel.kernelType} (${kernel.kernelName}) took ${(kernel.endTime - kernel.startTime) / 1000 / 1000} ms`);
});
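
For the Wasm baseline, the same model can be run with the wasm EP. A minimal sketch (variable names are illustrative; the numThreads line is optional tuning):

ort.env.wasm.numThreads = navigator.hardwareConcurrency;  // optional: use all cores

const wasmSession = await ort.InferenceSession.create("/convtest.onnx", {
    executionProviders: ["wasm"],
});

// Same input shape as the WebGPU test above.
const wasmInput = new ort.Tensor("float32", new Float32Array(8 * 4098 * 435), [8, 4098, 435]);

console.time("wasm run");
await wasmSession.run({ input: wasmInput });
console.timeEnd("wasm run");

await wasmSession.release();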

Urgency

Urgent

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.21.0-dev.20241224-2d05c4bcd9

Execution Provider

'webgpu' (WebGPU), 'wasm'/'cpu' (WebAssembly CPU)

@gianlourbano added the platform:web label Jan 7, 2025
@gianlourbano (Author)

@qjia7 @gyagp Could you please take a look? Maybe it has something to do with this PR.

@github-actions bot added the ep:WebGPU label Jan 7, 2025
@qjia7 (Contributor)

qjia7 commented Jan 8, 2025

@gianlourbano I can reproduce it. Will take a look, thanks.
