
Making LLMAttribution work with BertForMultipleChoice models #1524

@rbelew

Description

🚀 Feature

Allow the LLMAttribution goodness to be applied to BERT models for multiple-choice tasks

Motivation

Following up on suggestions from @aobo-y.

Pitch

Integrated-gradients attribution already works over BertForMultipleChoice (sketched below); it would be great if
FeatureAblation / LLMAttribution did, too.
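
For reference, the setup that does work looks roughly like this: a minimal sketch using Captum's LayerIntegratedGradients over the BERT embedding layer, assuming model, tst (the tokenizer output for the stacked choices), and targetIdxTensor (the index of the choice to explain) as in the snippets further down:

    import torch
    from captum.attr import LayerIntegratedGradients

    def multChoice_forward(input_ids, token_type_ids, attention_mask, target):
        output = model(input_ids, token_type_ids=token_type_ids,
                       attention_mask=attention_mask)
        # logits: (batch, num_choices); return the target choice's log-prob
        # as a scalar per example
        log_probs = torch.log_softmax(output.logits, dim=1)
        return log_probs[:, target]

    lig = LayerIntegratedGradients(multChoice_forward, model.bert.embeddings)
    attributions_ig = lig.attribute(
        tst['input_ids'],  # shape (1, num_choices, seq_len)
        additional_forward_args=(tst['token_type_ids'],
                                 tst['attention_mask'],
                                 targetIdxTensor),
    )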

Alternatives

Two suggestions were made:

First approach:

  • code:

    from captum.attr import FeatureAblation, LLMAttribution, TextTokenInput

    fa = FeatureAblation(model)
    llm_attr = LLMAttribution(fa, tokenizer)

    inp = TextTokenInput(promptTxt, tokenizer)

    attributions_fa = llm_attr.attribute(
        inp,
        target=targetIdxTensor,
        additional_forward_args=dict(
            token_type_ids=tst['token_type_ids'],
            # position_ids=position_ids,
            attention_mask=tst['attention_mask'],
        )
    )

  • throws error:

      File ".../captumPerturb_min.py", line 160, in captumPerturbOne
      attributions_fa = llm_attr.attribute(
      ^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/captum/attr/_core/llm_attr.py", line 674, in attribute
      cur_attr = self.attr_method.attribute(
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
      TypeError: captum.attr._core.feature_ablation.FeatureAblation.attribute() got multiple values for keyword argument 'additional_forward_args'
    
  • dropping the additional_forward_args parameter gets farther, but
    throws (more on both failures in the note after this list):

      File ".../captumPerturb_min.py", line 160, in captumPerturbOne
      attributions_fa = llm_attr.attribute(
      ^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/captum/attr/_core/llm_attr.py", line 674, in attribute
      cur_attr = self.attr_method.attribute(
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/captum/log/dummy_log.py", line 39, in wrapper
      return func(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/captum/attr/_core/feature_ablation.py", line 288, in attribute
      initial_eval: Union[Tensor, Future[Tensor]] = _run_forward(
      ^^^^^^^^^^^^^
      File ".../site-packages/captum/_utils/common.py", line 588, in _run_forward
      output = forward_func(
      ^^^^^^^^^^^^^
      File ".../site-packages/captum/attr/_core/llm_attr.py", line 574, in _forward_func
      model_inputs = prep_inputs_for_generation(  # type: ignore
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/transformers/generation/utils.py", line 376, in prepare_inputs_for_generation
      raise NotImplementedError(
      NotImplementedError: A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`.
    
  • looking at the _forward_func variables, self.model.prepare_inputs_for_generation is:

    <bound method GenerationMixin.prepare_inputs_for_generation of BertForMultipleChoice(
      (bert): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0-11): 12 x BertLayer(
              (attention): BertAttention(
                (self): BertSdpaSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
                (output): BertSelfOutput(
                  (dense): Linear(in_features=768, out_features=768, bias=True)
                  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
              )
              (intermediate): BertIntermediate(
                (dense): Linear(in_features=768, out_features=3072, bias=True)
                (intermediate_act_fn): GELUActivation()
              )
              (output): BertOutput(
                (dense): Linear(in_features=3072, out_features=768, bias=True)
                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
          )
        )
        (pooler): BertPooler(
          (dense): Linear(in_features=768, out_features=768, bias=True)
          (activation): Tanh()
        )
      )
      (dropout): Dropout(p=0.1, inplace=False)
      (classifier): Linear(in_features=768, out_features=1, bias=True)
    )>
  • model_inp: tensor, torch.Size([1, 112])

  • model_kwargs.keys():

      dict_keys(['attention_mask', 'cache_position', 'use_cache'])
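
For what it's worth, my reading of the two failures above (hedged; I may be misreading Captum's internals): LLMAttribution.attribute apparently supplies additional_forward_args to the wrapped attribution method itself, so passing it in again collides (the TypeError), and its _forward_func drives the model through the .generate() machinery, which BertForMultipleChoice only inherits as a GenerationMixin stub (the NotImplementedError). A quick check, assuming a transformers version that has PreTrainedModel.can_generate():

    from transformers import BertForMultipleChoice

    model = BertForMultipleChoice.from_pretrained('bert-base-uncased')
    # encoder-only classifier: prepare_inputs_for_generation is never
    # overridden, so the generation-based forward path cannot work
    print(model.can_generate())  # expected: False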
    

Second approach:

  • code:

    import torch

    def multChoice_forward(inputs, token_type_ids=None, position_ids=None,
                           attention_mask=None, target=None):
        output = model(inputs, token_type_ids=token_type_ids,
                       position_ids=position_ids, attention_mask=attention_mask)
        log_probs = torch.log_softmax(output.logits, 1)
        # specify which choice's prob
        return log_probs[target]

    fa = FeatureAblation(multChoice_forward)

    attributions_fa = fa.attribute(
        tst['input_ids'],
        additional_forward_args=dict(
            token_type_ids=tst['token_type_ids'],
            attention_mask=tst['attention_mask'],
            target=targetIdxTensor
        )
    )

  • throws:

      File ".../captumPerturb_min.py", line 294, in main
      captumPerturbOne(model,tokenizer,tstDict,tstTarget)
      File ".../captumPerturb_min.py", line 184, in captumPerturbOne
      attributions_fa = fa.attribute(
      ^^^^^^^^^^^^^
      File ".../site-packages/captum/log/dummy_log.py", line 39, in wrapper
      return func(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/captum/attr/_core/feature_ablation.py", line 288, in attribute
      initial_eval: Union[Tensor, Future[Tensor]] = _run_forward(
      ^^^^^^^^^^^^^
      File ".../site-packages/captum/_utils/common.py", line 588, in _run_forward
      output = forward_func(
      ^^^^^^^^^^^^^
      File ".../captumPerturb_min.py", line 175, in multChoice_forward
      output = model(inputs, token_type_ids=token_type_ids,
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
      return forward_call(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/transformers/models/bert/modeling_bert.py", line 1799, in forward
      token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
      ^^^^^^^^^^^^^^^^^^^
      AttributeError: 'dict' object has no attribute 'view'
    

This is too far into Transformer API-land for me to follow.
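
That said, one guess: the AttributeError is consistent with Captum passing additional_forward_args positionally, wrapping a non-tuple value into a one-element tuple, so the whole dict lands in token_type_ids. A sketch of the same fa.attribute call with a tuple in the forward signature's order, which I would expect to get past that error (untested):

    attributions_fa = fa.attribute(
        tst['input_ids'],
        additional_forward_args=(
            tst['token_type_ids'],   # token_type_ids
            None,                    # position_ids
            tst['attention_mask'],   # attention_mask
            targetIdxTensor,         # target: which choice's log-prob
        ),
    )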

Additional context

Additional details in the original issue #1523.
