[go: nahoru, domu]

Skip to content

Commit

Permalink
Fix a bug in matrix-vector einsum
Browse files Browse the repository at this point in the history
  • Loading branch information
romeric committed Sep 24, 2019
1 parent ec82d56 commit 8241ac8
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 46 deletions.
2 changes: 2 additions & 0 deletions dd/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Build the scratch einsum driver against the in-tree Fastor headers (../).
# `all` is phony and always recompiles: main.cpp's header dependencies in ../
# are not listed, so an unconditional rebuild is what picks up header edits.
# Recipe lines must start with a literal tab character.
.PHONY: all
all:
	$(CXX) main.cpp -o main -I../ -O3 -mavx $(CXXFLAGS)
Binary file added dd/main
Binary file not shown.
22 changes: 22 additions & 0 deletions dd/main.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include <Fastor.h>

using namespace Fastor;


int main()
{

enum{i,j,k};
Tensor<double, 2, 2> a = {{1, 2}, {3, 4}};
Tensor<double, 2> w = {1, 1};
// Tensor<double,2> e3 = einsum<Index<i, j>, Index <i> >(a, w);
// Tensor<double,2> e3 = einsum<Index<i, j>, Index <j> >(a, w);

// Tensor<double,2> e3 = einsum<Index<i>, Index <i,j> >(w, a);
Tensor<double,2> e3 = einsum<Index<j>, Index <i,j> >(w, a);


print(a,w,e3);

return 0;
}
122 changes: 76 additions & 46 deletions tensor_algebra/einsum.h
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,82 @@ auto einsum(const Tensor<T,Rest0...> &a, const Tensor<T,Rest1...> &b,
#endif



// Matrix-matrix contraction forwarded to _matmul:
//     result(i,k) = sum_j lhs(i,j) * rhs(j,k)
// Matched only when both index lists hold two entries, the second index of
// the left operand equals the first index of the right one (the contracted
// index), and the three indices involved are pairwise distinct -- so lists
// such as <i,i>,<i,k> or <i,j>,<j,i> are not matched by this overload.
// Matrix-vector and vector-matrix products whose vector is stored as an
// n x 1 or 1 x n tensor are covered here as well.
template<class Ind0, class Ind1,
         typename T, size_t I, size_t J, size_t K,
         typename std::enable_if<
             // size checks stay first so the [1] accesses below are never
             // evaluated for shorter index lists
             Ind0::NoIndices==2 && Ind1::NoIndices==2 &&
             // shared (contracted) index
             Ind1::_IndexHolder[0] == Ind0::_IndexHolder[1] &&
             // contracted index differs from both free indices
             Ind0::_IndexHolder[0] != Ind0::_IndexHolder[1] &&
             Ind1::_IndexHolder[1] != Ind0::_IndexHolder[1] &&
             // free indices differ from each other
             Ind1::_IndexHolder[1] != Ind0::_IndexHolder[0],
             bool>::type = 0>
FASTOR_INLINE Tensor<T,I,K>
einsum(const Tensor<T,I,J> &lhs, const Tensor<T,J,K> &rhs) {
    Tensor<T,I,K> result;
    // (I x J) * (J x K) -> (I x K)
    _matmul<T,I,J,K>(lhs.data(), rhs.data(), result.data());
    return result;
}


// Matrix-vector contraction:  result(i) = sum_j mat(i,j) * vec(j)
// Matched for index lists <i,j>,<j> with i != j. The vector is handed to
// _matmul as a J x 1 operand.
template<class Ind0, class Ind1,
         typename T, size_t I, size_t J,
         typename std::enable_if<
             Ind0::NoIndices==2 && Ind1::NoIndices==1 &&
             // vector index is the matrix's second (column) index ...
             Ind1::_IndexHolder[0] == Ind0::_IndexHolder[1] &&
             // ... and not also its first one
             Ind1::_IndexHolder[0] != Ind0::_IndexHolder[0],
             bool>::type = 0>
FASTOR_INLINE Tensor<T,I>
einsum(const Tensor<T,I,J> &mat, const Tensor<T,J> &vec) {
    Tensor<T,I> result;
    // (I x J) * (J x 1) -> (I x 1)
    _matmul<T,I,J,1>(mat.data(), vec.data(), result.data());
    return result;
}

// Transposed matrix-vector contraction:  result(j) = sum_i mat(i,j) * vec(i)
// Matched for index lists <i,j>,<i> with i != j (the vector shares the
// matrix's FIRST index). Evaluated as the row-vector product
// (1 x I) * (I x J), so the vector is passed as the left _matmul operand.
template<class Ind0, class Ind1,
         typename T, size_t I, size_t J,
         typename std::enable_if<
             Ind0::NoIndices==2 && Ind1::NoIndices==1 &&
             Ind1::_IndexHolder[0] == Ind0::_IndexHolder[0] &&
             Ind1::_IndexHolder[0] != Ind0::_IndexHolder[1],
             bool>::type = 0>
FASTOR_INLINE Tensor<T,J>
einsum(const Tensor<T,I,J> &mat, const Tensor<T,I> &vec) {
    Tensor<T,J> result;
    // (1 x I) * (I x J) -> (1 x J)
    _matmul<T,1,I,J>(vec.data(), mat.data(), result.data());
    return result;
}


// Vector-matrix contraction:  result(j) = sum_i vec(i) * mat(i,j)
// Matched for index lists <i>,<i,j> with i != j. The vector is handed to
// _matmul as a 1 x I row operand.
template<class Ind0, class Ind1,
         typename T, size_t I, size_t J,
         typename std::enable_if<
             Ind1::NoIndices==2 && Ind0::NoIndices==1 &&
             Ind0::_IndexHolder[0] == Ind1::_IndexHolder[0] &&
             Ind0::_IndexHolder[0] != Ind1::_IndexHolder[1],
             bool>::type = 0>
FASTOR_INLINE Tensor<T,J>
einsum(const Tensor<T,I> &vec, const Tensor<T,I,J> &mat) {
    Tensor<T,J> result;
    // (1 x I) * (I x J) -> (1 x J)
    _matmul<T,1,I,J>(vec.data(), mat.data(), result.data());
    return result;
}


// Vector listed first but contracted with the matrix's SECOND index:
//     result(i) = sum_j vec(j) * mat(i,j)
// Matched for index lists <j>,<i,j> with i != j. Numerically this is the
// ordinary matrix-vector product, so the matrix is the left _matmul operand
// and the vector a J x 1 column.
template<class Ind0, class Ind1,
         typename T, size_t I, size_t J,
         typename std::enable_if<
             Ind1::NoIndices==2 && Ind0::NoIndices==1 &&
             Ind0::_IndexHolder[0] == Ind1::_IndexHolder[1] &&
             Ind0::_IndexHolder[0] != Ind1::_IndexHolder[0],
             bool>::type = 0>
FASTOR_INLINE Tensor<T,I>
einsum(const Tensor<T,J> &vec, const Tensor<T,I,J> &mat) {
    Tensor<T,I> result;
    // (I x J) * (J x 1) -> (I x 1)
    _matmul<T,I,J,1>(mat.data(), vec.data(), result.data());
    return result;
}



#ifdef __AVX__

// Specific overloads
Expand Down Expand Up @@ -444,52 +520,6 @@ einsum(const Tensor<T,I,J> & a, const Tensor<T,K,L> &b) {
}


// Matrix-matrix contraction handed to _matmul:
//     result(i,k) = sum_j lhs(i,j) * rhs(j,k)
// Requires two indices per operand, a shared middle index, and pairwise
// distinct indices overall. A vector stored as an n x 1 or 1 x n tensor is
// treated as a matrix here, covering those matrix-vector/vector-matrix cases.
template<class Ind0, class Ind1,
         typename T, size_t I, size_t J, size_t K,
         typename std::enable_if<
             Ind0::NoIndices==2 && Ind1::NoIndices==2 &&
             Ind1::_IndexHolder[0] == Ind0::_IndexHolder[1] &&
             Ind0::_IndexHolder[0] != Ind0::_IndexHolder[1] &&
             Ind1::_IndexHolder[1] != Ind0::_IndexHolder[1] &&
             Ind1::_IndexHolder[1] != Ind0::_IndexHolder[0],
             bool>::type = 0>
FASTOR_INLINE Tensor<T,I,K>
einsum(const Tensor<T,I,J> &lhs, const Tensor<T,J,K> &rhs) {
    Tensor<T,I,K> product;
    _matmul<T,I,J,K>(lhs.data(), rhs.data(), product.data());
    return product;
}


// Matrix-vector contraction:  result(i) = sum_j mat(i,j) * vec(j)
// Matched when the vector's index equals the matrix's second index and the
// matrix's two indices differ. The vector is a J x 1 operand to _matmul.
template<class Ind0, class Ind1,
         typename T, size_t I, size_t J,
         typename std::enable_if<
             Ind0::NoIndices==2 && Ind1::NoIndices==1 &&
             Ind1::_IndexHolder[0] == Ind0::_IndexHolder[1] &&
             Ind0::_IndexHolder[0] != Ind0::_IndexHolder[1],
             bool>::type = 0>
FASTOR_INLINE Tensor<T,I>
einsum(const Tensor<T,I,J> &mat, const Tensor<T,J> &vec) {
    Tensor<T,I> product;
    _matmul<T,I,J,1>(mat.data(), vec.data(), product.data());
    return product;
}


// Vector-matrix contraction:  out(j) = sum_i a(i) * b(i,j)
// Matched for index lists <i>,<i,j> with i != j.
// Ind0 describes the VECTOR and Ind1 the MATRIX, so Ind0 must hold exactly
// one index and Ind1 two; the vector's index must be the matrix's first
// index and differ from its second.
template<class Ind0, class Ind1,
         typename T, size_t I, size_t J,
         typename std::enable_if<Ind0::NoIndices==1 && Ind1::NoIndices==2 &&
                                 Ind0::_IndexHolder[0] == Ind1::_IndexHolder[0] &&
                                 Ind0::_IndexHolder[0] != Ind1::_IndexHolder[1],bool>::type = 0>
FASTOR_INLINE Tensor<T,J>
einsum(const Tensor<T,I> &a, const Tensor<T,I,J> &b) {
    Tensor<T,J> out;
    // The vector is the 1 x I LEFT operand: (1 x I) * (I x J) -> (1 x J).
    // Passing it as an I x J matrix would read I*J values from an
    // I-element buffer and write I results into a J-element output.
    _matmul<T,1,I,J>(a.data(), b.data(), out.data());
    return out;
}


// The following two overloads are provided for an external use case
// A_ijk*B_kl
template<class Ind0, class Ind1,
Expand Down
11 changes: 11 additions & 0 deletions tests/test_einsum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,17 @@ void run() {
assert(abs(As.sum() - Bs4.sum()) < BigTol);
}

{
    // Matrix-vector / vector-matrix contractions on a NON-square matrix, so a
    // transposed or mis-shaped dispatch changes the result.
    // As = [[1,2],[3,4],[5,6]] (iota fills storage starting from 1)
    Tensor<T,3,2> As; As.iota(1);
    Tensor<T,3> bs; bs.fill(1);
    Tensor<T,2> cs; cs.fill(2);

    // Declared shapes are the expected output extents.
    Tensor<T,3> mv  = einsum<Index<i,j>,Index<j>>(As,cs);   // sum_j As_ij cs_j = {6,14,22}
    Tensor<T,2> mtv = einsum<Index<i,j>,Index<i>>(As,bs);   // sum_i As_ij bs_i = {9,12}
    Tensor<T,2> vm  = einsum<Index<i>,Index<i,j>>(bs,As);   // sum_i bs_i As_ij = {9,12}
    Tensor<T,3> vmt = einsum<Index<j>,Index<i,j>>(cs,As);   // sum_j cs_j As_ij = {6,14,22}

    // std::abs is essential: a one-sided "x - expected < Tol" also passes for
    // any result that is too small (e.g. all zeros).
    assert(std::abs(mv.sum()  - 42.) < Tol);
    assert(std::abs(mtv.sum() - 21.) < Tol);
    assert(std::abs(vm.sum()  - 21.) < Tol);
    assert(std::abs(vmt.sum() - 42.) < Tol);

    // Element-wise checks catch permuted/transposed results with the right sum.
    assert(std::abs(mv.data()[0]  -  6.) < Tol);
    assert(std::abs(mv.data()[1]  - 14.) < Tol);
    assert(std::abs(mv.data()[2]  - 22.) < Tol);
    assert(std::abs(mtv.data()[0] -  9.) < Tol);
    assert(std::abs(mtv.data()[1] - 12.) < Tol);
    assert(std::abs(vm.data()[0]  -  9.) < Tol);
    assert(std::abs(vm.data()[1]  - 12.) < Tol);
    assert(std::abs(vmt.data()[0] -  6.) < Tol);
    assert(std::abs(vmt.data()[1] - 14.) < Tol);
    assert(std::abs(vmt.data()[2] - 22.) < Tol);
}

print(FGRN(BOLD("All tests passed successfully")));
}

Expand Down

0 comments on commit 8241ac8

Please sign in to comment.