[go: nahoru, domu]

Skip to content

Commit

Permalink
Fix bug in tensor matmul for matrix-vector case. einsum for matrix-vector also dispatches to matmul now, which is much faster.
Browse files Browse the repository at this point in the history
  • Loading branch information
romeric committed Apr 23, 2019
1 parent 8f4c6ae commit 899c6c0
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 4 deletions.
4 changes: 2 additions & 2 deletions tensor/TensorFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ FASTOR_INLINE Tensor<T,I,K> matmul(const Tensor<T,I,J> &a, const Tensor<T,J,K> &
}

// Matrix-vector product: out(i) = sum_j a(i,j) * b(j).
// `a` is I x J and `b` has J entries, so the result has one entry per row
// of `a`, i.e. it is a Tensor<T,I> (not Tensor<T,J>, which would only be
// correct for square matrices).
template<typename T, size_t I, size_t J>
FASTOR_INLINE Tensor<T,I> matmul(const Tensor<T,I,J> &a, const Tensor<T,J> &b) {
    // Hack clang to get around alignment
#if defined(__llvm__) || defined(__clang__)
    unused(a);
#endif
    Tensor<T,I> out;
    // b is treated as a J x 1 matrix: (I x J) * (J x 1) -> (I x 1)
    _matmul<T,I,J,1>(a.data(),b.data(),out.data());
    return out;
}
Expand Down
33 changes: 31 additions & 2 deletions tensor_algebra/einsum.h
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,9 @@ einsum(const Tensor<T,I,J> & a, const Tensor<T,K,L> &b) {
}


// matmul dispatcher for 2nd order tensors (matrix-matrix): multiplies an
// I x J tensor `a` by a J x K tensor `b` and returns the I x K result.
// This also covers matrix-vector and vector-matrix products when the vector
// is stored as an n x 1 or 1 x n 2nd order tensor.
// NOTE(review): part of the enable_if index-matching constraint is folded
// away in this view — confirm the full condition against the source file.
template<class Ind0, class Ind1,
typename T, size_t I, size_t J, size_t K,
typename std::enable_if<Ind0::NoIndices==2 && Ind1::NoIndices==2 &&
Expand All @@ -454,13 +456,40 @@ template<class Ind0, class Ind1,
Ind0::_IndexHolder[0] != Ind1::_IndexHolder[1],bool>::type = 0>
FASTOR_INLINE Tensor<T,I,K>
einsum(const Tensor<T,I,J> &a, const Tensor<T,J,K> &b) {

Tensor<T,I,K> out;
// dimensions passed as <T,I,J,K>: (I x J) * (J x K) -> (I x K), matching the signature
_matmul<T,I,J,K>(a.data(),b.data(),out.data());
return out;
}


// matmul dispatcher for matrix-vector products, e.g. einsum<Index<i,j>,Index<j>>(A, x).
// Enabled when the matrix's second index equals the vector's single index
// and the matrix's two indices are distinct. The vector `b` is handed to
// _matmul as a J x 1 matrix, so the I x J by J x 1 product yields I entries.
template<class Ind0, class Ind1,
         typename T, size_t I, size_t J,
         typename std::enable_if<Ind0::NoIndices==2 && Ind1::NoIndices==1 &&
                                 Ind0::_IndexHolder[1] == Ind1::_IndexHolder[0] &&
                                 Ind0::_IndexHolder[1] != Ind0::_IndexHolder[0],bool>::type = 0>
FASTOR_INLINE Tensor<T,I>
einsum(const Tensor<T,I,J> &a, const Tensor<T,J> &b) {
    Tensor<T,I> result;
    auto lhs = a.data();
    auto rhs = b.data();
    _matmul<T,I,J,1>(lhs, rhs, result.data());
    return result;
}


// matmul dispatcher for vector-matrix products, e.g. einsum<Index<i>,Index<i,j>>(x, B).
// Computes out(j) = sum_i a(i) * b(i,j) by treating the vector `a` as a
// 1 x I row vector.
//
// Ind0 describes the vector `a` (one index) and Ind1 the matrix `b` (two
// indices), matching this overload's argument order. The vector's index
// must equal the matrix's first index, and the matrix's two indices must
// differ. NoIndices is tested first so `&&` short-circuits before any
// out-of-range _IndexHolder element could be evaluated.
template<class Ind0, class Ind1,
         typename T, size_t I, size_t J,
         typename std::enable_if<Ind0::NoIndices==1 && Ind1::NoIndices==2 &&
                                 Ind0::_IndexHolder[0] == Ind1::_IndexHolder[0] &&
                                 Ind1::_IndexHolder[0] != Ind1::_IndexHolder[1],bool>::type = 0>
FASTOR_INLINE Tensor<T,J>
einsum(const Tensor<T,I> &a, const Tensor<T,I,J> &b) {
    Tensor<T,J> out;
    // (1 x I) * (I x J) -> (1 x J). Same <T,M,K,N> argument order as the
    // matrix-matrix dispatcher's _matmul<T,I,J,K> call. The previous
    // <T,I,J,1> read I*J elements from the I-element vector `a`.
    _matmul<T,1,I,J>(a.data(),b.data(),out.data());
    return out;
}


// The following two overloads are provided for an external use case
// A_ijk*B_kl
template<class Ind0, class Ind1,
Expand Down

0 comments on commit 899c6c0

Please sign in to comment.