3-dimensional arrays don't compute the dot product of sub-matrices correctly #212

@mihbor

Description

Taking the dot product of 2-D sub-matrices sliced from a 3-dimensional array gives wrong results, unless the 3-D array has only one top-level element.
Reproducer:

import org.jetbrains.kotlinx.multik.api.linalg.dot
import org.jetbrains.kotlinx.multik.api.mk
import org.jetbrains.kotlinx.multik.api.ndarray
import org.jetbrains.kotlinx.multik.ndarray.data.get

fun main() {

  val a = mk.ndarray(
    mk[
      mk[
        mk[1.0],
        mk[0.0],
      ],
    ],
  )
  val b = mk.ndarray(
    mk[
      mk[
        mk[1.0],
        mk[0.0],
      ],
      mk[
        mk[1.0],
        mk[0.0],
      ],
    ],
  )
  println("a[0] dot a[0].T: (CORRECT)")
  println(a[0] dot a[0].transpose())

  println("\nb[0] dot b[0].T: (WRONG)")
  println(b[0] dot b[0].transpose())

  println("\nb[1] dot b[1].T: (WRONG)")
  println(b[1] dot b[1].transpose())
}

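Since b[0] and b[1] hold the same values as a[0], all three products should be [[1.0, 0.0], [0.0, 0.0]]. Instead, the program prints: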

a[0] dot a[0].T: (CORRECT)
[[1.0, 0.0],
[0.0, 0.0]]

b[0] dot b[0].T: (WRONG)
[[0.0, 0.0],
[0.0, 0.0]]

b[1] dot b[1].T: (WRONG)
[[0.0, 0.0],
[0.0, 0.0]]
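For comparison, the expected result can be reproduced without Multik. The sketch below is a hypothetical reference check: `View2D` and its `dot` are illustrative names, not Multik API. It stores `b` row-major in a flat buffer and reads each 2-D slice through an explicit offset and strides. It is not a claim about how Multik's native engine works internally.

```kotlin
// Hypothetical reference check, independent of Multik.
// A 2-D view over a flat row-major buffer, addressed by offset + strides.
class View2D(
    val data: DoubleArray,
    val offset: Int,
    val rows: Int,
    val cols: Int,
    val rowStride: Int,
    val colStride: Int,
) {
    operator fun get(i: Int, j: Int): Double =
        data[offset + i * rowStride + j * colStride]

    // Transposition swaps shape and strides; no data is copied.
    fun transpose() = View2D(data, offset, cols, rows, colStride, rowStride)
}

// Naive matrix product of two views; returns a freshly allocated result.
infix fun View2D.dot(other: View2D): Array<DoubleArray> {
    require(cols == other.rows) { "shape mismatch: $cols vs ${other.rows}" }
    return Array(rows) { i ->
        DoubleArray(other.cols) { j ->
            var sum = 0.0
            for (k in 0 until cols) sum += this[i, k] * other[k, j]
            sum
        }
    }
}

fun main() {
    // Same values as `b` in the reproducer: shape (2, 2, 1), strides (2, 1, 1).
    val b = doubleArrayOf(1.0, 0.0, 1.0, 0.0)
    for (n in 0 until 2) {
        // b[n] is the 2x1 sub-matrix starting at flat offset n * 2.
        val slice = View2D(b, offset = n * 2, rows = 2, cols = 1, rowStride = 1, colStride = 1)
        val product = slice dot slice.transpose()
        println("b[$n] dot b[$n].T = ${product.map { it.toList() }}")
    }
}
```

Both iterations print `[[1.0, 0.0], [0.0, 0.0]]`, matching the a[0] case. Honouring the slice's offset is what makes b[1] read elements 2 and 3 of the buffer rather than 0 and 1.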

Metadata

Assignees: no one assigned

Labels: bug (Something isn't working), native (An issue/PR related to Native)