
How to pool and visualize attn_maps #18

@donarni

Description

Thank you very much for your effort!

I have a question regarding your attn_maps. As far as I understand,

  1. attn_maps stores the cross-attention values for each timestep as a list.
  2. At each timestep, attn_maps contains a dict whose keys are the cross-attention layer names.
  3. Each value is of shape prompt x head x height x width x tokens (2 because of the conditional and unconditional pass, 77 because of prompt padding)

Here is your example output for SD 2.1:

down_blocks.0.attentions.0.transformer_blocks.0.attn2 torch.Size([2, 5, 96, 96, 77])
down_blocks.0.attentions.1.transformer_blocks.0.attn2 torch.Size([2, 5, 96, 96, 77])
down_blocks.1.attentions.0.transformer_blocks.0.attn2 torch.Size([2, 10, 48, 48, 77])
down_blocks.1.attentions.1.transformer_blocks.0.attn2 torch.Size([2, 10, 48, 48, 77])
down_blocks.2.attentions.0.transformer_blocks.0.attn2 torch.Size([2, 20, 24, 24, 77])
down_blocks.2.attentions.1.transformer_blocks.0.attn2 torch.Size([2, 20, 24, 24, 77])
mid_block.attentions.0.transformer_blocks.0.attn2 torch.Size([2, 20, 12, 12, 77])
up_blocks.1.attentions.0.transformer_blocks.0.attn2 torch.Size([2, 20, 24, 24, 77])
up_blocks.1.attentions.1.transformer_blocks.0.attn2 torch.Size([2, 20, 24, 24, 77])
up_blocks.1.attentions.2.transformer_blocks.0.attn2 torch.Size([2, 20, 24, 24, 77])
up_blocks.2.attentions.0.transformer_blocks.0.attn2 torch.Size([2, 10, 48, 48, 77])
up_blocks.2.attentions.1.transformer_blocks.0.attn2 torch.Size([2, 10, 48, 48, 77])
up_blocks.2.attentions.2.transformer_blocks.0.attn2 torch.Size([2, 10, 48, 48, 77])
up_blocks.3.attentions.0.transformer_blocks.0.attn2 torch.Size([2, 5, 96, 96, 77])
up_blocks.3.attentions.1.transformer_blocks.0.attn2 torch.Size([2, 5, 96, 96, 77])
up_blocks.3.attentions.2.transformer_blocks.0.attn2 torch.Size([2, 5, 96, 96, 77])
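To make my understanding concrete, here is a minimal mock of that structure (a sketch for testing pooling code without running the pipeline; the helper name is my own, the shapes are copied from the output above, and only two layers are mocked):

```python
import torch

def mock_attn_maps(num_timesteps=3):
    """Mock of the attn_maps structure described above: a list over
    timesteps, each a dict mapping layer name -> tensor of shape
    (prompt, head, height, width, tokens)."""
    layer_shapes = {
        "down_blocks.0.attentions.0.transformer_blocks.0.attn2": (2, 5, 96, 96, 77),
        "mid_block.attentions.0.transformer_blocks.0.attn2": (2, 20, 12, 12, 77),
    }
    return [
        {name: torch.rand(*shape) for name, shape in layer_shapes.items()}
        for _ in range(num_timesteps)
    ]
```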

Could you tell me how you created your visualizations and across which dimensions you pooled, preferably in line with https://github.com/castorini/daam. Do you use bicubic interpolation for middle layers?
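For context, this is my current attempt in the DAAM style (a sketch under my own assumptions, not your code: I take the conditional half at index 1, average over heads, select one token's map, bicubic-upsample each layer to a common resolution, then average over layers and timesteps):

```python
import torch
import torch.nn.functional as F

def heatmap_for_token(attn_maps, token_idx, out_size=96):
    """DAAM-style pooling sketch. attn_maps is a list over timesteps of
    dicts mapping layer name -> (prompt, head, height, width, tokens)."""
    maps = []
    for step in attn_maps:                 # list over timesteps
        for attn in step.values():         # dict: layer -> tensor
            cond = attn[1]                 # conditional pass: (head, h, w, tokens)
            # select one token's map and average over attention heads
            token_map = cond[..., token_idx].mean(dim=0)     # (h, w)
            # bicubic upsample to a common resolution
            up = F.interpolate(
                token_map[None, None],     # (1, 1, h, w)
                size=(out_size, out_size),
                mode="bicubic",
                align_corners=False,
            )[0, 0]
            maps.append(up)
    heat = torch.stack(maps).mean(dim=0)   # average over layers and timesteps
    # normalize to [0, 1] for plotting, e.g. with matplotlib's imshow
    return (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)
```

Is this roughly what you do, or do you weight the layers differently?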
