Skip to content

Conversation

@dongxianzhe
Copy link
Contributor

optimize layernorm kernel using half2 type
test layernorm kernel

}

template <>
void invoke_layernorm_kernel<half2>(half2* out,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds this template specializations are optional since they are covered by the general template. no?

const float epsilon,
int m,
int n) {
int half_n = n / 2;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if n % 2 != 0?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds you didn't cover this in unittest.

float* dinput;
float* dweight;
float* dbias;
cudaMalloc((void**)&dout, sizeof(float) * m * n);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use torch::tensor to allocate memory

torch::nn::functional::LayerNormFuncOptions({n}).weight(weight).bias(
bias));

half* hout = (half*)malloc(m * n * sizeof(half));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here.

cudaMemcpy(dweight, hweight, sizeof(half) * n, cudaMemcpyHostToDevice);
cudaMemcpy(dbias, hbias, sizeof(half) * n, cudaMemcpyHostToDevice);

llm::kernel::invoke_layernorm_kernel<half>(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just test llm::kernel::layer_norm instead but pass in different length of input to trigger different kernel.

Copy link
Collaborator

@guocuimi guocuimi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for adding the optimization. could you also add benchmark to show the improvements? thanks

@guocuimi guocuimi changed the title [op] layernorm kernel [kernel] added half2 specialization for layernorm kernel Apr 22, 2024
@dongxianzhe dongxianzhe force-pushed the op/layernorm_kernel branch from e18e337 to bc9f7e2 Compare April 27, 2024 06:22
…nitest and just test llm::kernel::layer_norm
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants