The Precision Issue of the GELU Operator #5692
My money is on the fact that Triton, by default, uses fast approximations to compute transcendental functions (tanh in this case) while PyTorch uses better (but slower) approximations. But again, without a repro this is just a guess.
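For reference, here is a minimal PyTorch-only sketch (not from the original thread) showing the size of the gap between the erf-based GELU and its tanh approximation in float32; the roughly 1e-4 difference is the same order of magnitude as the discrepancy reported in this issue.

```python
import torch
import torch.nn.functional as F

# Compare the erf-based ("exact") GELU with the tanh approximation in float32,
# using a float64 evaluation as the reference.
x = torch.ones(4, dtype=torch.float32)

exact = F.gelu(x)                        # erf-based GELU (PyTorch default)
approx = F.gelu(x, approximate='tanh')   # tanh approximation
reference = F.gelu(x.double())           # float64 reference

print(exact)                                       # ~0.8413
print(approx)                                      # ~0.8412
print((approx.double() - reference).abs().max())   # on the order of 1e-4
```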
Hi, I am actually using the library https://github.com/BobMcDear/attorch and calling its fused MLP operator. This operator applies an activation function after the linear layer. For testing, I set the linear layer to an identity mapping and the activation function to GELU. With the original GELU implementation found in the library:

```python
@triton.jit
def gelu(input):
    cdf = 0.5 * (1 + tl.math.erf(0.707106781 * input))
    return cdf * input
```

I noticed some precision issues. So I made some changes to its implementation, but even after approximating it with

```python
@triton.jit
def tanh(input):
    return 2 * tl.sigmoid(2 * input) - 1


@triton.jit
def gelu(input):
    input_cube = input * input * input
    inner = 0.7978845608 * (input + 0.044715 * input_cube)
    cdf = 0.5 * (1 + tanh(inner))
    return cdf * input
```

the results still do not match exactly. Below is my test code; simply git clone that repository and place the test script at the same directory level as the attorch package:
```python
import torch
import torch.nn as nn

from attorch.attorch import Linear


class FusedMLP(nn.Module):
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.act = nn.GELU()
        self.fc1 = nn.Linear(hidden_size, intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        return hidden_states


def accuracy_test():
    hidden_size = 2
    intermediate_size = 4
    x = torch.rand(1, 2, device=torch.device('cuda'))
    print(x)

    torch_linear = FusedMLP(hidden_size, intermediate_size)
    attorch_linear = Linear(hidden_size, intermediate_size, act_func='gelu')

    with torch.no_grad():
        torch_linear.fc1.weight = nn.Parameter(torch.ones(intermediate_size, hidden_size))  # set to 1
        torch_linear.fc1.bias = nn.Parameter(torch.zeros(intermediate_size))  # set to 0
        torch_linear = torch_linear.to('cuda')

    with torch.no_grad():
        attorch_linear.weight = nn.Parameter(torch.ones(hidden_size, intermediate_size))  # set to 1
        attorch_linear.bias = nn.Parameter(torch.zeros(intermediate_size))  # set to 0
        attorch_linear = attorch_linear.to("cuda")

    y_torch = torch_linear(x)
    y_attorch = attorch_linear(x)

    print(torch.allclose(y_attorch, y_torch))
    print(y_attorch)
    print(y_torch)


accuracy_test()
```

Thanks for your help!
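One way to narrow this down (a sketch, not from the original thread) is to run the erf-based GELU formula by itself in a tiny Triton kernel, with no matmul in front of it, and compare against a float64 reference; if the standalone activation already shows the ~1e-4 gap, the fused matmul can be ruled out as the source. The kernel name and launch configuration below are made up for illustration.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def erf_gelu_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Same erf-based formula as the original attorch gelu, evaluated in isolation.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = 0.5 * x * (1 + tl.math.erf(0.707106781 * x))
    tl.store(out_ptr + offsets, y, mask=mask)


# 2.0 is roughly the pre-activation magnitude behind the ~1.9545 outputs
# reported in this issue (GELU(2) is about 1.9545).
x = torch.full((4,), 2.0, device='cuda')
out = torch.empty_like(x)
erf_gelu_kernel[(1,)](x, out, x.numel(), BLOCK_SIZE=4)

reference = torch.nn.functional.gelu(x.double())     # float64 reference
print(out)                                           # Triton float32 result
print(torch.nn.functional.gelu(x))                   # PyTorch float32 erf-based GELU
print((out.double() - reference).abs().max())        # error of the activation alone
```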
Here's the PyTorch implementation:
Describe the issue
When the input is [1, 1, 1, 1] with dtype=float32, the outputs from PyTorch and Triton are as follows:

PyTorch output:
tensor([[1.9546, 1.9546, 1.9546, 1.9546]], device='cuda:0')

Triton output:
tensor([[1.9545, 1.9545, 1.9545, 1.9545]], device='cuda:0')
Is this level of precision difference within a reasonable range? Could it be caused by internal mechanisms within Triton?
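For scale, here is a small sketch (not part of the original report) that quantifies the gap between the two printed outputs: the difference is about 1e-4 absolute, roughly 5e-5 relative, which is larger than torch.allclose's default tolerances (rtol=1e-5, atol=1e-8) but consistent in magnitude with switching between erf-based and tanh-based GELU formulations in float32.

```python
import torch

# Hypothetical check using the values printed above, not code from the issue.
y_torch = torch.tensor([1.9546, 1.9546, 1.9546, 1.9546])
y_triton = torch.tensor([1.9545, 1.9545, 1.9545, 1.9545])

print((y_torch - y_triton).abs().max())                    # ~1e-4 absolute difference
print(((y_torch - y_triton).abs() / y_torch.abs()).max())  # ~5e-5 relative difference
print(torch.allclose(y_torch, y_triton))                   # False under default tolerances
```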
Environment details
GPU: RTX 4090 (24 GB)
Triton 2.1