@@ -249,25 +249,26 @@ def influence( # type: ignore[override]
249249 ) -> Union [Tensor , KMostInfluentialResults ]:
250250 r"""
251251 This is the key method of this class, and can be run in 3 different modes,
252- where the mode that is run depends on the arguments passed to this method.
252+ where the mode that is run depends on the arguments passed to this method:
253+
253254 - self influence mode: This mode is used if `inputs` is None. This mode
254- computes the self influence scores for every example in
255- the training dataset `influence_src_dataset`.
255+ computes the self influence scores for every example in
256+ the training dataset `influence_src_dataset`.
256257 - influence score mode: This mode is used if `inputs` is not None, and `k` is
257- None. This mode computes the influence score of every example in
258- training dataset `influence_src_dataset` on every example in the test
259- batch represented by `inputs` and `targets`.
258+ None. This mode computes the influence score of every example in
259+ training dataset `influence_src_dataset` on every example in the test
260+ batch represented by `inputs` and `targets`.
260261 - k-most influential mode: This mode is used if `inputs` is not None, and
261- `k` is not None, and an int. This mode computes the proponents or
262- opponents of every example in the test batch represented by `inputs`
263- and `targets`. In particular, for each test example in the test batch,
264- this mode computes its proponents (resp. opponents), which are the
265- indices in the training dataset `influence_src_dataset` of the training
266- examples with the `k` highest (resp. lowest) influence scores on the
267- test example. Proponents are computed if `proponents` is True.
268- Otherwise, opponents are computed. For each test example, this method
269- also returns the actual influence score of each proponent (resp.
270- opponent) on the test example.
262+ `k` is not None, and an int. This mode computes the proponents or
263+ opponents of every example in the test batch represented by `inputs`
264+ and `targets`. In particular, for each test example in the test batch,
265+ this mode computes its proponents (resp. opponents), which are the
266+ indices in the training dataset `influence_src_dataset` of the training
267+ examples with the `k` highest (resp. lowest) influence scores on the
268+ test example. Proponents are computed if `proponents` is True.
269+ Otherwise, opponents are computed. For each test example, this method
270+ also returns the actual influence score of each proponent (resp.
271+ opponent) on the test example.
271272
272273 Args:
273274 inputs (Any, optional): If not provided or `None`, the self influence mode
@@ -300,33 +301,34 @@ def influence( # type: ignore[override]
300301
301302 Returns:
302303 The return value of this method depends on which mode is run.
304+
303305 - self influence mode: if this mode is run (`inputs` is None), returns a 1D
304- tensor of self influence scores over training dataset
305- `influence_src_dataset`. The length of this tensor is the number of
306- examples in `influence_src_dataset`, regardless of whether it is a
307- Dataset or DataLoader.
306+ tensor of self influence scores over training dataset
307+ `influence_src_dataset`. The length of this tensor is the number of
308+ examples in `influence_src_dataset`, regardless of whether it is a
309+ Dataset or DataLoader.
308310 - influence score mode: if this mode is run (`inputs is not None, `k` is
309- None), returns a 2D tensor `influence_scores` of shape
310- `(input_size, influence_src_dataset_size)`, where `input_size` is
311- the number of examples in the test batch, and
312- `influence_src_dataset_size` is the number of examples in
313- training dataset `influence_src_dataset`. In other words,
314- `influence_scores[i][j]` is the influence score of the `j`-th
315- example in `influence_src_dataset` on the `i`-th example in the
316- test batch.
317- - k-most influential mode: if this mode is run (`inputs` is not None,
318- `k` is an int), returns a namedtuple `(indices, influence_scores)`.
319- `indices` is a 2D tensor of shape `(input_size, k)`, where
320- `input_size` is the number of examples in the test batch. If
321- computing proponents (resp. opponents), `indices[i][j]` is the
322- index in training dataset `influence_src_dataset` of the example
323- with the `j`-th highest (resp. lowest) influence score (out of the
324- examples in `influence_src_dataset`) on the `i`-th example in the
325- test batch. `influence_scores` contains the corresponding influence
326- scores. In particular, `influence_scores[i][j]` is the influence
327- score of example `indices[i][j]` in `influence_src_dataset` on
328- example `i` in the test batch represented by `inputs` and
329- `targets`.
311+ None), returns a 2D tensor `influence_scores` of shape
312+ `(input_size, influence_src_dataset_size)`, where `input_size` is
313+ the number of examples in the test batch, and
314+ `influence_src_dataset_size` is the number of examples in
315+ training dataset `influence_src_dataset`. In other words,
316+ `influence_scores[i][j]` is the influence score of the `j`-th
317+ example in `influence_src_dataset` on the `i`-th example in the
318+ test batch.
319+ - k-most influential mode: if this mode is run (`inputs` is not None,
320+ `k` is an int), returns a namedtuple `(indices, influence_scores)`.
321+ `indices` is a 2D tensor of shape `(input_size, k)`, where
322+ `input_size` is the number of examples in the test batch. If
323+ computing proponents (resp. opponents), `indices[i][j]` is the
324+ index in training dataset `influence_src_dataset` of the example
325+ with the `j`-th highest (resp. lowest) influence score (out of the
326+ examples in `influence_src_dataset`) on the `i`-th example in the
327+ test batch. `influence_scores` contains the corresponding influence
328+ scores. In particular, `influence_scores[i][j]` is the influence
329+ score of example `indices[i][j]` in `influence_src_dataset` on
330+ example `i` in the test batch represented by `inputs` and
331+ `targets`.
330332 """
331333 _inputs = _format_inputs (inputs , unpack_inputs )
332334
0 commit comments