
Arbitrary number of batch dimensions and of returned samples #7

@forgi86


Hi,

Thanks for the very nice package; I found it very useful! I have added two features in this fork:

https://github.com/forgi86/mixture-density-network

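  1. Support for more than one "batch" dimension. For instance, my version also works if the following lines are added to the example script:
    # split the data into num_seq sequences of seq_len steps each;
    # the total number of samples must equal num_seq * seq_len
    num_seq = 32
    seq_len = 16
    x = x.reshape(num_seq, seq_len, nx)
    y = y.reshape(num_seq, seq_len, ny)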

before training. My real use case is a mixture density head on top of an RNN backbone for sequential data (as the variable names in the code snippet above suggest). With this modification I no longer have to reshape the features before applying the mixture head; it can be fed the RNN features directly.
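A minimal sketch of why extra batch dimensions come almost for free: if the head only ever acts on the last (feature) axis, matrix multiplication broadcasts over all leading axes, and reshapes only need to split that last axis using `h.shape[:-1]`. This uses NumPy arrays rather than the package's own tensor code, omits bias terms, and the names `MDNHead`, `mdn_log_prob` and `logsumexp` are hypothetical, not the package's API. The usage at the end checks that a `(num_seq, seq_len, ...)` input gives the same log-likelihoods as the flattened one.

```python
import numpy as np


def logsumexp(a, axis=-1, keepdims=False):
    # numerically stable log(sum(exp(a))) along one axis
    m = np.max(a, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))
    return out if keepdims else np.squeeze(out, axis=axis)


class MDNHead:
    """Hypothetical mixture density head (no biases) acting on the last axis only."""

    def __init__(self, nh, ny, n_components, rng):
        self.ny, self.k = ny, n_components
        self.w_pi = rng.normal(scale=0.1, size=(nh, n_components))
        self.w_mu = rng.normal(scale=0.1, size=(nh, n_components * ny))
        self.w_sigma = rng.normal(scale=0.1, size=(nh, n_components * ny))

    def __call__(self, h):
        batch_shape = h.shape[:-1]  # any number of leading dims, possibly none
        logits = h @ self.w_pi  # matmul broadcasts over leading dims: (..., K)
        log_pi = logits - logsumexp(logits, keepdims=True)
        # split only the last axis into (K, ny); leading dims are untouched
        mu = (h @ self.w_mu).reshape(*batch_shape, self.k, self.ny)
        sigma = np.exp(h @ self.w_sigma).reshape(*batch_shape, self.k, self.ny)
        return log_pi, mu, sigma


def mdn_log_prob(y, log_pi, mu, sigma):
    # y: (..., ny) -> (..., 1, ny) so it broadcasts against the K components
    y = y[..., None, :]
    z = (y - mu) / sigma
    # diagonal Gaussian log-density per component, summed over output dims: (..., K)
    log_comp = np.sum(-0.5 * z**2 - np.log(sigma) - 0.5 * np.log(2 * np.pi), axis=-1)
    return logsumexp(log_pi + log_comp)  # (...,)


rng = np.random.default_rng(0)
num_seq, seq_len, nh, ny, n_components = 32, 16, 8, 2, 3
head = MDNHead(nh, ny, n_components, rng)
h = rng.normal(size=(num_seq, seq_len, nh))  # e.g. features from an RNN backbone
y = rng.normal(size=(num_seq, seq_len, ny))

lp = mdn_log_prob(y, *head(h))  # shape (32, 16), no manual reshape needed
lp_flat = mdn_log_prob(y.reshape(-1, ny), *head(h.reshape(-1, nh)))  # shape (512,)
print(lp.shape)  # (32, 16)
```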

  2. Support for sampling multiple values. I added the parameters samples and squeeze to the sample method of the MixtureDensityNetwork class. The default settings (samples=1, squeeze=True) are backward compatible. If instead samples > 1, the second-to-last dimension of the returned tensor is the sample index.
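A minimal sketch of the sampling semantics described above, for a diagonal Gaussian mixture with any number of batch dimensions: draws have shape `(..., samples, ny)`, and the sample axis is dropped only when `samples == 1` and `squeeze` is true. It is written in NumPy, picks mixture components with the Gumbel-max trick (one valid way to sample a categorical distribution; the fork may do it differently), and `sample_mixture` is a hypothetical name, not the method in the fork.

```python
import numpy as np


def sample_mixture(log_pi, mu, sigma, samples=1, squeeze=True, rng=None):
    """Hypothetical sampler for a diagonal Gaussian mixture.

    log_pi: (..., K) log mixture weights; mu, sigma: (..., K, ny).
    Returns (..., samples, ny), or (..., ny) when samples == 1 and squeeze is True.
    """
    rng = np.random.default_rng() if rng is None else rng
    batch_shape, k = log_pi.shape[:-1], log_pi.shape[-1]
    ny = mu.shape[-1]
    # Gumbel-max trick: argmax(log_pi + Gumbel noise) is a draw from Categorical(pi)
    gumbel = rng.gumbel(size=(*batch_shape, samples, k))
    comp = np.argmax(log_pi[..., None, :] + gumbel, axis=-1)  # (..., samples)
    # gather the chosen component's mean and scale for every draw
    idx = np.repeat(comp[..., None], ny, axis=-1)  # (..., samples, ny)
    mu_c = np.take_along_axis(mu, idx, axis=-2)
    sigma_c = np.take_along_axis(sigma, idx, axis=-2)
    draws = mu_c + sigma_c * rng.standard_normal(size=mu_c.shape)
    if squeeze and samples == 1:
        draws = draws[..., 0, :]  # drop the sample axis: backward-compatible shape
    return draws


rng = np.random.default_rng(0)
num_seq, seq_len, n_components, ny = 32, 16, 3, 2
log_pi = np.log(np.full((num_seq, seq_len, n_components), 1.0 / n_components))
mu = rng.normal(size=(num_seq, seq_len, n_components, ny))
sigma = np.full((num_seq, seq_len, n_components, ny), 0.1)

one = sample_mixture(log_pi, mu, sigma, rng=rng)  # (32, 16, 2)
many = sample_mixture(log_pi, mu, sigma, samples=100, rng=rng)  # (32, 16, 100, 2)
kept = sample_mixture(log_pi, mu, sigma, squeeze=False, rng=rng)  # (32, 16, 1, 2)
```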

If you find either of these changes useful, I can open a PR; they are independent of each other and can be merged separately.

Cheers,
Marco
