[go: nahoru, domu]

Skip to content

Commit

Permalink
Fix a serious bug in strided_contraction for cases where the second/last tensor disappears
Browse files Browse the repository at this point in the history
  • Loading branch information
romeric committed Sep 26, 2019
1 parent 70838d2 commit 4ff2ea0
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 9 deletions.
6 changes: 6 additions & 0 deletions meta/einsum_meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,9 @@ template<typename T, size_t ...Idx0, size_t ...Idx1, size_t...Rest>
struct is_vectorisable<Index<Idx0...>,Index<Idx1...>,Tensor<T,Rest...>> {
static constexpr size_t fastest_changing_index = get_value<sizeof...(Rest),Rest...>::value;
static constexpr size_t idx[sizeof...(Idx0)] = {Idx0...};
static constexpr bool does_2nd_tensor_disappear = ((int)no_of_unique<Idx0...,Idx1...>::value == (int)sizeof...(Idx0) - (int)sizeof...(Idx1));
static constexpr bool last_index_contracted = contains(idx,get_value<sizeof...(Idx1),Idx1...>::value);
static constexpr bool is_reducible = does_2nd_tensor_disappear && last_index_contracted;
static constexpr bool value = (!last_index_contracted) && (fastest_changing_index % get_vector_size<T,FASTOR_SSE>::size==0);
static constexpr bool sse_vectorisability = (!last_index_contracted) &&
(fastest_changing_index % get_vector_size<T,FASTOR_SSE>::size==0 && fastest_changing_index % get_vector_size<T,FASTOR_AVX>::size!=0);
Expand All @@ -196,7 +198,9 @@ template<size_t ...Idx0, size_t ...Idx1, size_t...Rest>
struct is_vectorisable<Index<Idx0...>,Index<Idx1...>,Tensor<float,Rest...>> {
static constexpr size_t fastest_changing_index = get_value<sizeof...(Rest),Rest...>::value;
static constexpr size_t idx[sizeof...(Idx0)] = {Idx0...};
static constexpr bool does_2nd_tensor_disappear = ((int)no_of_unique<Idx0...,Idx1...>::value == (int)sizeof...(Idx0) - (int)sizeof...(Idx1));
static constexpr bool last_index_contracted = contains(idx,get_value<sizeof...(Idx1),Idx1...>::value);
static constexpr bool is_reducible = does_2nd_tensor_disappear && last_index_contracted;
static constexpr bool value = (!last_index_contracted) && (fastest_changing_index % 4==0);
static constexpr bool sse_vectorisability = (!last_index_contracted) && (fastest_changing_index % 4==0 && fastest_changing_index % 8!=0);
static constexpr bool avx_vectorisability = (!last_index_contracted) && (fastest_changing_index % 4==0 && fastest_changing_index % 8==0);
Expand All @@ -210,7 +214,9 @@ template<size_t ...Idx0, size_t ...Idx1, size_t...Rest>
struct is_vectorisable<Index<Idx0...>,Index<Idx1...>,Tensor<double,Rest...>> {
static constexpr size_t fastest_changing_index = get_value<sizeof...(Rest),Rest...>::value;
static constexpr size_t idx[sizeof...(Idx0)] = {Idx0...};
static constexpr bool does_2nd_tensor_disappear = ((int)no_of_unique<Idx0...,Idx1...>::value == (int)sizeof...(Idx0) - (int)sizeof...(Idx1));
static constexpr bool last_index_contracted = contains(idx,get_value<sizeof...(Idx1),Idx1...>::value);
static constexpr bool is_reducible = does_2nd_tensor_disappear && last_index_contracted;
static constexpr bool value = (!last_index_contracted) && (fastest_changing_index % 2==0);
static constexpr bool sse_vectorisability = (!last_index_contracted) && (fastest_changing_index % 2==0 && fastest_changing_index % 4!=0);
static constexpr bool avx_vectorisability = (!last_index_contracted) && (fastest_changing_index % 2==0 && fastest_changing_index % 4==0);
Expand Down
3 changes: 2 additions & 1 deletion simd_vector/simd_vector_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ struct SIMDVector {
}
// Sets the contents of this vector from an explicit list of scalars.
//
// The arguments (first, args...) are gathered, in the order given, into a
// temporary array of Size elements of type T, and that array is then copied
// into `value` in reverse order: the first argument lands in the last slot
// written and the last argument in the first slot.
//
// By the usual aggregate-initialisation rules, passing more than Size values
// does not compile, and passing fewer leaves the trailing elements of `arr`
// value-initialised (zero for arithmetic T — assumes T is arithmetic).
template<typename U, typename ... Args>
FASTOR_INLINE void set(U first, Args ... args) {
// NOTE(review): `first` is used in the array initialiser below, so this call
// looks like a leftover from an earlier version — confirm whether it is still needed.
unused(first);
// Collect the arguments in call order.
T arr[Size] = {first,args...};
// Store them reversed into `value`.
std::reverse_copy(arr, arr+Size, value);
// Relax this restriction
// static_assert(sizeof...(args)==1,"CANNOT SET VECTOR WITH VALUES DUE TO ABI CONSIDERATION");
}
Expand Down
15 changes: 8 additions & 7 deletions tensor_algebra/einsum.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ auto einsum(const Tensor<T,Rest0...> &a, const Tensor<T,Rest1...> &b)

// Dispatch to the right routine
using vectorisability = is_vectorisable<Index_I,Index_J,Tensor<T,Rest1...>>;
constexpr bool is_reducible = vectorisability::last_index_contracted;
// constexpr bool is_reducible = vectorisability::last_index_contracted;
constexpr bool is_reducible = vectorisability::is_reducible;
if (is_reducible) {
return extractor_reducible_contract<Index_I,Index_J>::contract_impl(a,b);
}
Expand Down Expand Up @@ -94,7 +95,7 @@ auto einsum(const Tensor<T,Rest0...> &a, const Tensor<T,Rest1...> &b, const Tens
// Dispatch to the right routine
using Index0 = typename concat_<Index_I,Index_J>::type;
using vectorisability = is_vectorisable<Index0,Index_K,Tensor<T,Rest2...>>;
constexpr bool is_reducible = vectorisability::last_index_contracted;
constexpr bool is_reducible = vectorisability::is_reducible;
if (is_reducible) {
return extractor_strided_contract<Index_I,Index_J,Index_K>::contract_impl(a,b,c);
}
Expand All @@ -117,7 +118,7 @@ auto einsum(const Tensor<T,Rest0...> &a, const Tensor<T,Rest1...> &b, const Tens
// Dispatch to the right routine
using Index0 = typename concat_<Index_I,Index_J,Index_K>::type;
using vectorisability = is_vectorisable<Index0,Index_L,Tensor<T,Rest3...>>;
constexpr bool is_reducible = vectorisability::last_index_contracted;
constexpr bool is_reducible = vectorisability::is_reducible;
if (is_reducible) {
return extractor_strided_contract_4<Index_I,Index_J,Index_K,Index_L>::contract_impl(a,b,c,d);
}
Expand All @@ -142,7 +143,7 @@ auto einsum(const Tensor<T,Rest0...> &a, const Tensor<T,Rest1...> &b,
// Dispatch to the right routine
using Index0 = typename concat_<Index_I,Index_J,Index_K,Index_L>::type;
using vectorisability = is_vectorisable<Index0,Index_M,Tensor<T,Rest4...>>;
constexpr bool is_reducible = vectorisability::last_index_contracted;
constexpr bool is_reducible = vectorisability::is_reducible;
if (is_reducible) {
return extractor_strided_contract_5<Index_I,Index_J,Index_K,Index_L,Index_M>::contract_impl(a,b,c,d,e);
}
Expand All @@ -166,7 +167,7 @@ auto einsum(const Tensor<T,Rest0...> &a, const Tensor<T,Rest1...> &b,
// Dispatch to the right routine
using Index0 = typename concat_<Index_I,Index_J,Index_K,Index_L,Index_M>::type;
using vectorisability = is_vectorisable<Index0,Index_N,Tensor<T,Rest5...>>;
constexpr bool is_reducible = vectorisability::last_index_contracted;
constexpr bool is_reducible = vectorisability::is_reducible;
if (is_reducible) {
return extractor_strided_contract_6<Index_I,Index_J,Index_K,Index_L,Index_M,Index_N>::contract_impl(a,b,c,d,e,f);
}
Expand All @@ -192,7 +193,7 @@ auto einsum(const Tensor<T,Rest0...> &a, const Tensor<T,Rest1...> &b,
// Dispatch to the right routine
using Index0 = typename concat_<Index_I,Index_J,Index_K,Index_L,Index_M,Index_N>::type;
using vectorisability = is_vectorisable<Index0,Index_O,Tensor<T,Rest6...>>;
constexpr bool is_reducible = vectorisability::last_index_contracted;
constexpr bool is_reducible = vectorisability::is_reducible;
if (is_reducible) {
return extractor_strided_contract_7<Index_I,Index_J,Index_K,Index_L,Index_M,Index_N,Index_O>::contract_impl(a,b,c,d,e,f,g);
}
Expand All @@ -219,7 +220,7 @@ auto einsum(const Tensor<T,Rest0...> &a, const Tensor<T,Rest1...> &b,
// Dispatch to the right routine
using Index0 = typename concat_<Index_I,Index_J,Index_K,Index_L,Index_M,Index_N,Index_O>::type;
using vectorisability = is_vectorisable<Index0,Index_P,Tensor<T,Rest7...>>;
constexpr bool is_reducible = vectorisability::last_index_contracted;
constexpr bool is_reducible = vectorisability::is_reducible;
if (is_reducible) {
return extractor_strided_contract_8<Index_I,Index_J,Index_K,Index_L,Index_M,Index_N,Index_O,Index_P>::contract_impl(a,b,c,d,e,f,g,h);
}
Expand Down
2 changes: 1 addition & 1 deletion tensor_algebra/strided_contraction.h
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ struct extractor_reducible_contract<Index<Idx0...>, Index<Idx1...>> {

// Free-function entry point for the contraction of two tensors.
//
// Index_I and Index_J are the index lists attached to `a` and `b` respectively.
// The call is forwarded unchanged to
// extractor_reducible_contract<Index_I,Index_J>::contract_impl(a,b), and the
// return type is deduced from that same expression through the trailing decltype.
//
// NOTE(review): two signature lines appear below, one without and one with
// FASTOR_INLINE; only one of them can belong to the real source — confirm
// against the header.
template<class Index_I, class Index_J,
typename T, size_t ... Rest0, size_t ... Rest1>
auto strided_contraction(const Tensor<T,Rest0...> &a, const Tensor<T,Rest1...> &b)
FASTOR_INLINE auto strided_contraction(const Tensor<T,Rest0...> &a, const Tensor<T,Rest1...> &b)
-> decltype(extractor_reducible_contract<Index_I,Index_J>::contract_impl(a,b)) {
return extractor_reducible_contract<Index_I,Index_J>::contract_impl(a,b);
}
Expand Down
25 changes: 25 additions & 0 deletions tests/test_einsum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,31 @@ void run() {
assert((einsum<Index<j>,Index<i,j>>(cs,As)).sum() - 42. < Tol);
}

{
// Test strided_contraction when second tensor disappears
// Every index of the second operand also appears in the first, so the
// second tensor is contracted away completely and only one free index of
// `a` is left in each result.
// The expected values below are consistent with iota(1) filling the tensors
// with consecutive values starting at 1 in row-major order
// (a holds 1..64, b holds 1..16).
Tensor<T,4,4,4> a; a.iota(1);
Tensor<T,4,4> b; b.iota(1);

// c1(i) = sum_{j,k} a(i,j,k)*b(j,k)   -- last two indices of a contracted
// c2(j) = sum_{i,k} a(i,j,k)*b(i,k)   -- first and last indices of a contracted
// c3(k) = sum_{i,j} a(i,j,k)*b(i,j)   -- first two indices of a contracted
Tensor<T,4> c1 = einsum<Index<i,j,k>,Index<j,k> >(a,b);
Tensor<T,4> c2 = einsum<Index<i,j,k>,Index<i,k> >(a,b);
Tensor<T,4> c3 = einsum<Index<i,j,k>,Index<i,j> >(a,b);

// c1(0) is 1^2 + 2^2 + ... + 16^2 = 1496; each further entry grows by 16*sum(b) = 2176.
assert (abs(c1(0) - 1496.) < Tol);
assert (abs(c1(1) - 3672.) < Tol);
assert (abs(c1(2) - 5848.) < Tol);
assert (abs(c1(3) - 8024.) < Tol);

// Consecutive entries of c2 differ by 4*sum(b) = 544.
assert (abs(c2(0) - 4904.) < Tol);
assert (abs(c2(1) - 5448.) < Tol);
assert (abs(c2(2) - 5992.) < Tol);
assert (abs(c2(3) - 6536.) < Tol);

// Consecutive entries of c3 differ by sum(b) = 136.
assert (abs(c3(0) - 5576.) < Tol);
assert (abs(c3(1) - 5712.) < Tol);
assert (abs(c3(2) - 5848.) < Tol);
assert (abs(c3(3) - 5984.) < Tol);
}

print(FGRN(BOLD("All tests passed successfully")));
}

Expand Down

0 comments on commit 4ff2ea0

Please sign in to comment.