fixed dot len 4

This commit is contained in:
Jorijn van der Graaf 2026-03-27 05:54:37 +01:00
commit bd9230e07e
2 changed files with 262 additions and 82 deletions

View file

@ -893,82 +893,58 @@ namespace Crafter {
VectorF16<Len, Packing> E,
VectorF16<Len, Packing> G
) requires(Len == 4 && Packing*Len == Alignment) {
constexpr std::array<std::uint8_t, Alignment*2> shuffleMaskA = GetShuffleMaskEpi8<{{0,0,0,0}}>();
constexpr std::array<std::uint8_t, Alignment*2> shuffleMaskC = GetShuffleMaskEpi8<{{1,1,1,1}}>();
constexpr std::array<std::uint8_t, Alignment*2> shuffleMaskE = GetShuffleMaskEpi8<{{2,2,2,2}}>();
constexpr std::array<std::uint8_t, Alignment*2> shuffleMaskG = GetShuffleMaskEpi8<{{3,3,3,3}}>();
if constexpr(std::is_same_v<VectorType, __m128h>) {
VectorF16<Len, Packing> lenght = Length(A, C, E, G);
VectorF16<1, 8> lenght = LengthNoShuffle(A, E, C, G);
constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1};
__m128h one = _mm_loadu_ph(oneArr);
__m128h fLenght = _mm_div_ph(one, lenght.v);
VectorF16<8, 1> fLenght(_mm_div_ph(one, lenght.v));
__m128i shuffleVecA = _mm_loadu_epi8(shuffleMaskA.data());
__m128h fLenghtA = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecA));
__m128i shuffleVecC = _mm_loadu_epi8(shuffleMaskC.data());
__m128h fLenghtC = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecC));
__m128i shuffleVecE = _mm_loadu_epi8(shuffleMaskE.data());
__m128h fLenghtE = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecE));
__m128i shuffleVecG = _mm_loadu_epi8(shuffleMaskG.data());
__m128h fLenghtG = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecG));
VectorF16<8, 1> fLenghtA = fLenght.template Shuffle<{{0,0,0,0,1,1,1,1}}>();
VectorF16<8, 1> fLenghtC = fLenght.template Shuffle<{{2,2,2,2,3,3,3,3}}>();
VectorF16<8, 1> fLenghtE = fLenght.template Shuffle<{{4,4,4,4,5,5,5,5}}>();
VectorF16<8, 1> fLenghtG = fLenght.template Shuffle<{{6,6,6,6,7,7,7,7}}>();
return {
_mm_mul_ph(A.v, fLenghtA),
_mm_mul_ph(C.v, fLenghtC),
_mm_mul_ph(E.v, fLenghtE),
_mm_mul_ph(G.v, fLenghtG),
_mm_mul_ph(A.v, fLenghtA.v),
_mm_mul_ph(C.v, fLenghtC.v),
_mm_mul_ph(E.v, fLenghtE.v),
_mm_mul_ph(G.v, fLenghtG.v)
};
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
VectorF16<Len, Packing> lenght = Length(A, C, E, G);
VectorF16<1, 16> lenght = LengthNoShuffle(A, E, C, G);
constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
__m256h one = _mm256_loadu_ph(oneArr);
__m256h fLenght = _mm256_div_ph(one, lenght.v);
VectorF16<16, 1> fLenght(_mm256_div_ph(one, lenght.v));
__m256i shuffleVecA = _mm256_loadu_epi8(shuffleMaskA.data());
__m256h fLenghtA = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecA));
__m256i shuffleVecC = _mm256_loadu_epi8(shuffleMaskC.data());
__m256h fLenghtC = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecC));
__m256i shuffleVecE = _mm256_loadu_epi8(shuffleMaskE.data());
__m256h fLenghtE = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecE));
__m256i shuffleVecG = _mm256_loadu_epi8(shuffleMaskG.data());
__m256h fLenghtG = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecG));
VectorF16<16, 1> fLenghtA = fLenght.template Shuffle<{{0,0,0,0,1,1,1,1,8,8,8,8,9,9,9,9}}>();
VectorF16<16, 1> fLenghtC = fLenght.template Shuffle<{{2,2,2,2,3,3,3,3,10,10,10,10,11,11,11,11}}>();
VectorF16<16, 1> fLenghtE = fLenght.template Shuffle<{{4,4,4,4,5,5,5,5,12,12,12,12,13,13,13,13}}>();
VectorF16<16, 1> fLenghtG = fLenght.template Shuffle<{{6,6,6,6,7,7,7,7,14,14,14,14,15,15,15,15}}>();
return {
_mm256_mul_ph(A.v, fLenghtA),
_mm256_mul_ph(C.v, fLenghtC),
_mm256_mul_ph(E.v, fLenghtE),
_mm256_mul_ph(G.v, fLenghtG),
_mm256_mul_ph(A.v, fLenghtA.v),
_mm256_mul_ph(C.v, fLenghtC.v),
_mm256_mul_ph(E.v, fLenghtE.v),
_mm256_mul_ph(G.v, fLenghtG.v)
};
} else {
VectorF16<Len, Packing> lenght = Length(A, C, E, G);
VectorF16<1, 32> lenght = LengthNoShuffle(A, E, C, G);
constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
__m512h one = _mm512_loadu_ph(oneArr);
__m512h fLenght = _mm512_div_ph(one, lenght.v);
VectorF16<32, 1> fLenght(_mm512_div_ph(one, lenght.v));
VectorF16<32, 1> fLenght2(lenght.v);
__m512i shuffleVecA = _mm512_loadu_epi8(shuffleMaskA.data());
__m512h fLenghtA = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecA));
__m512i shuffleVecC = _mm512_loadu_epi8(shuffleMaskC.data());
__m512h fLenghtC = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecC));
__m512i shuffleVecE = _mm512_loadu_epi8(shuffleMaskE.data());
__m512h fLenghtE = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecE));
__m512i shuffleVecG = _mm512_loadu_epi8(shuffleMaskG.data());
__m512h fLenghtG = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecG));
VectorF16<32, 1> fLenghtA = fLenght.template Shuffle<{{0,0,0,0,1,1,1,1,8,8,8,8,9,9,9,9,16,16,16,16,17,17,17,17,24,24,24,24,25,25,25,25}}>();
VectorF16<32, 1> fLenghtC = fLenght.template Shuffle<{{2,2,2,2,3,3,3,3,10,10,10,10,11,11,11,11,18,18,18,18,19,19,19,19,26,26,26,26,27,27,27,27}}>();
VectorF16<32, 1> fLenghtE = fLenght.template Shuffle<{{4,4,4,4,5,5,5,5,12,12,12,12,13,13,13,13,20,20,20,20,21,21,21,21,28,28,28,28,29,29,29,29}}>();
VectorF16<32, 1> fLenghtG = fLenght.template Shuffle<{{6,6,6,6,7,7,7,7,14,14,14,14,15,15,15,15,22,22,22,22,23,23,23,23,30,30,30,30,31,31,31,31}}>();
return {
VectorF16<Len, Packing>(_mm512_mul_ph(A.v, fLenghtA)),
VectorF16<Len, Packing>(_mm512_mul_ph(C.v, fLenghtC)),
VectorF16<Len, Packing>(_mm512_mul_ph(E.v, fLenghtE)),
VectorF16<Len, Packing>(_mm512_mul_ph(G.v, fLenghtG)),
VectorF16<Len, Packing>(_mm512_mul_ph(A.v, fLenghtA.v)),
VectorF16<Len, Packing>(_mm512_mul_ph(C.v, fLenghtC.v)),
VectorF16<Len, Packing>(_mm512_mul_ph(E.v, fLenghtE.v)),
VectorF16<Len, Packing>(_mm512_mul_ph(G.v, fLenghtG.v)),
};
}
}
@ -978,7 +954,7 @@ namespace Crafter {
VectorF16<Len, Packing> E
) requires(Len == 2 && Packing*Len == Alignment) {
if constexpr(std::is_same_v<VectorType, __m128h>) {
VectorF16<1, 8> lenght = Length(A, E);
VectorF16<1, 8> lenght = LengthNoShuffle(A, E);
constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1};
__m128h one = _mm_loadu_ph(oneArr);
VectorF16<8, 1> fLenght(_mm_div_ph(one, lenght.v));
@ -991,26 +967,26 @@ namespace Crafter {
_mm_mul_ph(E.v, fLenghtE.v),
};
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
VectorF16<1, 16> lenght = Length(A, E);
VectorF16<1, 16> lenght = LengthNoShuffle(A, E);
constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
__m256h one = _mm256_loadu_ph(oneArr);
VectorF16<16, 1> fLenght(_mm256_div_ph(one, lenght.v));
VectorF16<16, 1> fLenghtA = fLenght.template Shuffle<{{0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7}}>();
VectorF16<16, 1> fLenghtE = fLenght.template Shuffle<{{8,8,9,9,10,10,11,11,12,12,13,13,14,14,15,15}}>();
VectorF16<16, 1> fLenghtA = fLenght.template Shuffle<{{0,0,1,1,2,2,3,3,8,8,9,9,10,10,11,11}}>();
VectorF16<16, 1> fLenghtE = fLenght.template Shuffle<{{4,4,5,5,6,6,7,7,12,12,13,13,14,14,15,15}}>();
return {
_mm256_mul_ph(A.v, fLenghtA.v),
_mm256_mul_ph(E.v, fLenghtE.v),
};
} else {
VectorF16<1, 32> lenght = Length(A, E);
VectorF16<1, 32> lenght = LengthNoShuffle(A, E);
constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
__m512h one = _mm512_loadu_ph(oneArr);
VectorF16<32, 1> fLenght(_mm512_div_ph(one, lenght.v));
VectorF16<32, 1> fLenghtA = fLenght.template Shuffle<{{0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13,14,14,15,15}}>();
VectorF16<32, 1> fLenghtE = fLenght.template Shuffle<{{16,16,17,17,18,18,19,19,20,20,21,21,22,22,23,23,24,24,25,25,26,26,27,27,28,28,29,29,30,30,31,31}}>();
VectorF16<32, 1> fLenghtA = fLenght.template Shuffle<{{0,0,1,1,2,2,3,3,8,8,9,9,10,10,11,11,16,16,17,17,18,18,19,19,24,24,25,25,26,26,27,27}}>();
VectorF16<32, 1> fLenghtE = fLenght.template Shuffle<{{4,4,5,5,6,6,7,7,12,12,13,13,14,14,15,15,20,20,21,21,22,22,23,23,28,28,29,29,30,30,31,31}}>();
return {
_mm512_mul_ph(A.v, fLenghtA.v),
@ -1108,6 +1084,163 @@ namespace Crafter {
VectorF16<Len, Packing> G0, VectorF16<Len, Packing> G1,
VectorF16<Len, Packing> H0, VectorF16<Len, Packing> H1
) requires(Len == 8 && Packing*Len == Alignment) {
// Len == 8 dot product: delegate the arithmetic to DotNoShuffle, then (for the
// wider registers) permute the result lanes with Shuffle.
//
// All eight operand pairs A..H must be forwarded: the Len == 8 DotNoShuffle
// overload takes sixteen arguments, so leaving out F0/F1 makes the call fail
// to resolve as soon as this function is instantiated.
// NOTE(review): F0/F1 are assumed to be declared in the part of this
// function's signature that precedes this hunk — confirm.
if constexpr(std::is_same_v<VectorType, __m128h>) {
    // 128-bit register: the raw result is returned without a lane permutation.
    return DotNoShuffle(A0, A1, B0, B1, C0, C1, D0, D1, E0, E1, F0, F1, G0, G1, H0, H1);
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
    VectorF16<16, 1> vec(DotNoShuffle(A0, A1, B0, B1, C0, C1, D0, D1, E0, E1, F0, F1, G0, G1, H0, H1).v);
    // Regroup the result: lanes 0-3 and 8-11 first, then lanes 4-7 and 12-15.
    vec = vec.template Shuffle<{{0,1,2,3,8,9,10,11,4,5,6,7,12,13,14,15}}>();
    return vec.v;
} else {
    VectorF16<32, 1> vec(DotNoShuffle(A0, A1, B0, B1, C0, C1, D0, D1, E0, E1, F0, F1, G0, G1, H0, H1).v);
    // Same regrouping as the 256-bit case, extended over four 8-lane groups.
    vec = vec.template Shuffle<{{0,1,2,3,8,9,10,11,16,17,18,19,24,25,26,27,4,5,6,7,12,13,14,15,20,21,22,23,28,29,30,31}}>();
    return vec.v;
}
}
/// Dot products of four operand pairs (A0/A1, C0/C1, E0/E1, G0/G1) for Len == 4.
///
/// The arithmetic is delegated to DotNoShuffle; this wrapper only fixes up the
/// lane order of the result:
///  - __m128h: no shuffle is issued; C and E are passed in swapped positions
///    instead (presumably compensating for DotNoShuffle's interleaved output
///    order — confirm against its lane comments).
///  - __m256h / otherwise (512-bit): the raw result is permuted with Shuffle.
///
/// Every shuffle mask must be a permutation of the lane indices. Each group of
/// the 512-bit mask follows base + {0,1,8,9,16,17,24,25} for base = 0, 4, 2, 6.
///
/// @return One VectorF16<1, Packing*4> holding the shuffled DotNoShuffle result.
constexpr static VectorF16<1, Packing*4> Dot(
    VectorF16<Len, Packing> A0, VectorF16<Len, Packing> A1,
    VectorF16<Len, Packing> C0, VectorF16<Len, Packing> C1,
    VectorF16<Len, Packing> E0, VectorF16<Len, Packing> E1,
    VectorF16<Len, Packing> G0, VectorF16<Len, Packing> G1
) requires(Len == 4 && Packing*Len == Alignment) {
    if constexpr(std::is_same_v<VectorType, __m128h>) {
        // Argument order A, E, C, G replaces a post-shuffle for the 128-bit case.
        return DotNoShuffle(A0, A1, E0, E1, C0, C1, G0, G1);
    } else if constexpr(std::is_same_v<VectorType, __m256h>) {
        VectorF16<16, 1> vec(DotNoShuffle(A0, A1, C0, C1, E0, E1, G0, G1).v);
        // Groups use base + {0,1,8,9} for base = 0, 4, 2, 6.
        vec = vec.template Shuffle<{{
            0,1,8,9,
            4,5,12,13,
            2,3,10,11,
            6,7,14,15
        }}>();
        return vec.v;
    } else {
        VectorF16<32, 1> vec(DotNoShuffle(A0, A1, C0, C1, E0, E1, G0, G1).v);
        // Groups use base + {0,1,8,9,16,17,24,25} for base = 0, 4, 2, 6.
        // The base-2 group therefore ends in 26,27; using 24,25 there would
        // duplicate lanes 24/25 and drop lanes 26/27 from the output.
        vec = vec.template Shuffle<{{
            0,1,8,9,
            16,17,24,25,
            4,5,12,13,
            20,21,28,29,
            2,3,10,11,
            18,19,26,27,
            6,7,14,15,
            22,23,30,31
        }}>();
        return vec.v;
    }
}
/// Dot products of two operand pairs (A0/A1 and E0/E1) for Len == 2.
///
/// DotNoShuffle does the arithmetic. For __m128h its result is returned as-is;
/// for the wider register types the lanes are regrouped with Shuffle so that
/// lanes 0-3 of every 8-lane group come first, followed by lanes 4-7.
///
/// @return One VectorF16<1, Packing*2> holding the (re-ordered) result.
constexpr static VectorF16<1, Packing*2> Dot(
    VectorF16<Len, Packing> A0, VectorF16<Len, Packing> A1,
    VectorF16<Len, Packing> E0, VectorF16<Len, Packing> E1
) requires(Len == 2 && Packing*Len == Alignment) {
    // The same DotNoShuffle call feeds every branch, so evaluate it once up front.
    const VectorF16<1, Packing*2> raw = DotNoShuffle(A0, A1, E0, E1);
    if constexpr(std::is_same_v<VectorType, __m128h>) {
        return raw;
    } else if constexpr(std::is_same_v<VectorType, __m256h>) {
        VectorF16<16, 1> lanes(raw.v);
        const VectorF16<16, 1> ordered =
            lanes.template Shuffle<{{0,1,2,3,8,9,10,11,4,5,6,7,12,13,14,15}}>();
        return ordered.v;
    } else {
        VectorF16<32, 1> lanes(raw.v);
        const VectorF16<32, 1> ordered =
            lanes.template Shuffle<{{0,1,2,3,8,9,10,11,16,17,18,19,24,25,26,27,4,5,6,7,12,13,14,15,20,21,22,23,28,29,30,31}}>();
        return ordered.v;
    }
}
private:
/// Length of eight vectors (Len == 8) without any lane re-ordering.
///
/// Takes the element-wise square root of LengthSqNoShuffle(A..H); the result
/// lanes stay in whatever order LengthSqNoShuffle produced them.
///
/// @return VectorF16<1, Packing*8> with the square-rooted lanes.
constexpr static VectorF16<1, Packing*8> LengthNoShuffle(
    VectorF16<Len, Packing> A, VectorF16<Len, Packing> B,
    VectorF16<Len, Packing> C, VectorF16<Len, Packing> D,
    VectorF16<Len, Packing> E, VectorF16<Len, Packing> F,
    VectorF16<Len, Packing> G, VectorF16<Len, Packing> H
) requires(Len == 8 && Packing*Len == Alignment) {
    using Result = VectorF16<1, Packing*8>;
    const Result squared = LengthSqNoShuffle(A, B, C, D, E, F, G, H);
    // Pick the sqrt intrinsic that matches the register width of VectorType.
    if constexpr(std::is_same_v<VectorType, __m128h>) {
        return Result(_mm_sqrt_ph(squared.v));
    } else if constexpr(std::is_same_v<VectorType, __m256h>) {
        return Result(_mm256_sqrt_ph(squared.v));
    } else {
        return Result(_mm512_sqrt_ph(squared.v));
    }
}
/// Length of four vectors (Len == 4) without any lane re-ordering.
///
/// Takes the element-wise square root of LengthSqNoShuffle(A, C, E, G); the
/// result lanes stay in whatever order LengthSqNoShuffle produced them.
///
/// @return VectorF16<1, Packing*4> with the square-rooted lanes.
constexpr static VectorF16<1, Packing*4> LengthNoShuffle(
    VectorF16<Len, Packing> A, VectorF16<Len, Packing> C,
    VectorF16<Len, Packing> E, VectorF16<Len, Packing> G
) requires(Len == 4 && Packing*Len == Alignment) {
    using Result = VectorF16<1, Packing*4>;
    const Result squared = LengthSqNoShuffle(A, C, E, G);
    // Pick the sqrt intrinsic that matches the register width of VectorType.
    if constexpr(std::is_same_v<VectorType, __m128h>) {
        return Result(_mm_sqrt_ph(squared.v));
    } else if constexpr(std::is_same_v<VectorType, __m256h>) {
        return Result(_mm256_sqrt_ph(squared.v));
    } else {
        return Result(_mm512_sqrt_ph(squared.v));
    }
}
/// Length of two vectors (Len == 2) without any lane re-ordering.
///
/// Takes the element-wise square root of LengthSqNoShuffle(A, E); the result
/// lanes stay in whatever order LengthSqNoShuffle produced them.
///
/// @return VectorF16<1, Packing*2> with the square-rooted lanes.
constexpr static VectorF16<1, Packing*2> LengthNoShuffle(
    VectorF16<Len, Packing> A, VectorF16<Len, Packing> E
) requires(Len == 2 && Packing*Len == Alignment) {
    using Result = VectorF16<1, Packing*2>;
    const Result squared = LengthSqNoShuffle(A, E);
    // Pick the sqrt intrinsic that matches the register width of VectorType.
    if constexpr(std::is_same_v<VectorType, __m128h>) {
        return Result(_mm_sqrt_ph(squared.v));
    } else if constexpr(std::is_same_v<VectorType, __m256h>) {
        return Result(_mm256_sqrt_ph(squared.v));
    } else {
        return Result(_mm512_sqrt_ph(squared.v));
    }
}
/// Squared length of eight vectors (Len == 8), lanes left un-shuffled.
///
/// Implemented as DotNoShuffle with every vector paired with itself, so the
/// lane order of the result is exactly DotNoShuffle's.
constexpr static VectorF16<1, Packing*8> LengthSqNoShuffle(
    VectorF16<Len, Packing> A, VectorF16<Len, Packing> B,
    VectorF16<Len, Packing> C, VectorF16<Len, Packing> D,
    VectorF16<Len, Packing> E, VectorF16<Len, Packing> F,
    VectorF16<Len, Packing> G, VectorF16<Len, Packing> H
) requires(Len == 8 && Packing*Len == Alignment) {
    VectorF16<1, Packing*8> selfDot = DotNoShuffle(
        A, A, B, B, C, C, D, D,
        E, E, F, F, G, G, H, H
    );
    return selfDot;
}
/// Squared length of four vectors (Len == 4), lanes left un-shuffled.
///
/// Implemented as DotNoShuffle with every vector paired with itself, so the
/// lane order of the result is exactly DotNoShuffle's.
constexpr static VectorF16<1, Packing*4> LengthSqNoShuffle(
    VectorF16<Len, Packing> A, VectorF16<Len, Packing> C,
    VectorF16<Len, Packing> E, VectorF16<Len, Packing> G
) requires(Len == 4 && Packing*Len == Alignment) {
    VectorF16<1, Packing*4> selfDot = DotNoShuffle(
        A, A, C, C,
        E, E, G, G
    );
    return selfDot;
}
/// Squared length of two vectors (Len == 2), lanes left un-shuffled.
///
/// Implemented as DotNoShuffle with every vector paired with itself, so the
/// lane order of the result is exactly DotNoShuffle's.
constexpr static VectorF16<1, Packing*2> LengthSqNoShuffle(
    VectorF16<Len, Packing> A, VectorF16<Len, Packing> E
) requires(Len == 2 && Packing*Len == Alignment) {
    VectorF16<1, Packing*2> selfDot = DotNoShuffle(A, A, E, E);
    return selfDot;
}
constexpr static VectorF16<1, Packing*8> DotNoShuffle(
VectorF16<Len, Packing> A0, VectorF16<Len, Packing> A1,
VectorF16<Len, Packing> B0, VectorF16<Len, Packing> B1,
VectorF16<Len, Packing> C0, VectorF16<Len, Packing> C1,
VectorF16<Len, Packing> D0, VectorF16<Len, Packing> D1,
VectorF16<Len, Packing> E0, VectorF16<Len, Packing> E1,
VectorF16<Len, Packing> F0, VectorF16<Len, Packing> F1,
VectorF16<Len, Packing> G0, VectorF16<Len, Packing> G1,
VectorF16<Len, Packing> H0, VectorF16<Len, Packing> H1
) requires(Len == 8 && Packing*Len == Alignment) {
if constexpr(std::is_same_v<VectorType, __m128h>) {
__m128h mulA = _mm_mul_ph(A0.v, A1.v);
__m128h mulB = _mm_mul_ph(B0.v, B1.v);
@ -1270,7 +1403,7 @@ namespace Crafter {
}
}
constexpr static VectorF16<1, Packing*4> Dot(
constexpr static VectorF16<1, Packing*4> DotNoShuffle(
VectorF16<Len, Packing> A0, VectorF16<Len, Packing> A1,
VectorF16<Len, Packing> C0, VectorF16<Len, Packing> C1,
VectorF16<Len, Packing> E0, VectorF16<Len, Packing> E1,
@ -1279,20 +1412,22 @@ namespace Crafter {
if constexpr(std::is_same_v<VectorType, __m128h>) {
__m128h mulA = _mm_mul_ph(A0.v, A1.v);
__m128h mulC = _mm_mul_ph(C0.v, C1.v);
__m128i row12Temp1 = _mm_unpacklo_epi16(_mm_castph_si128(mulA), _mm_castph_si128(mulC)); // A1 C1 A2 C2 A3 C3 A4 C4
__m128i row34Temp1 = _mm_unpackhi_epi16(_mm_castph_si128(mulA), _mm_castph_si128(mulC)); // B1 D1 B2 D2 B3 D3 B4 D4
__m128i row1TempTemp1 = row12Temp1;
__m128i row5TempTemp1 = row34Temp1;
__m128h mulE = _mm_mul_ph(E0.v, E1.v);
__m128h mulG = _mm_mul_ph(G0.v, G1.v);
__m128i row12Temp2 = _mm_unpacklo_epi16(_mm_castph_si128(mulE), _mm_castph_si128(mulG)); // E1 G1 E2 G2 E3 G3 E4 G4
__m128i row12Temp2Temp = row12Temp2;
__m128i row34Temp2 = _mm_unpackhi_epi16(_mm_castph_si128(mulE), _mm_castph_si128(mulG)); // F1 H1 F2 H2 F3 H3 F4 H4
row12Temp1 = _mm_unpacklo_epi16(row12Temp1, row12Temp2); // A1 E1 C1 G1 A2 E2 C2 G2
row12Temp2 = _mm_unpacklo_epi16(row34Temp1, row34Temp2); // B1 F1 D1 H1 B2 F2 D2 H2
row34Temp1 = _mm_unpackhi_epi16(row1TempTemp1, row34Temp1); // A3 E3 C3 G3 A4 E4 C4 G4
row34Temp2 = _mm_unpackhi_epi16(row5TempTemp1, row34Temp2); // B3 F3 D3 H3 B4 F4 D4 H4
row34Temp2 = _mm_unpackhi_epi16(row34Temp1, row34Temp2); // B3 F3 D3 H3 B4 F4 D4 H4
row34Temp1 = _mm_unpackhi_epi16(row1TempTemp1, row12Temp2Temp); // A3 E3 C3 G3 A4 E4 C4 G4
__m128h row1 = _mm_castsi128_ph(_mm_unpacklo_epi16(row12Temp1, row12Temp2));// A1 B1 E1 F1 C1 D1 G1 H1
__m128h row2 = _mm_castsi128_ph(_mm_unpackhi_epi16(row12Temp1, row12Temp2));// A2 B2 E2 F2 C2 D2 G2 H2
@ -1307,20 +1442,22 @@ namespace Crafter {
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
__m256h mulA = _mm256_mul_ph(A0.v, A1.v);
__m256h mulC = _mm256_mul_ph(C0.v, C1.v);
__m256i row12Temp1 = _mm256_unpacklo_epi16(_mm256_castph_si256(mulA), _mm256_castph_si256(mulC)); // A1 C1 A2 C2 A3 C3 A4 C4
__m256i row34Temp1 = _mm256_unpackhi_epi16(_mm256_castph_si256(mulA), _mm256_castph_si256(mulC)); // B1 D1 B2 D2 B3 D3 B4 D4
__m256i row1TempTemp1 = row12Temp1;
__m256i row5TempTemp1 = row34Temp1;
__m256h mulE = _mm256_mul_ph(E0.v, E1.v);
__m256h mulG = _mm256_mul_ph(G0.v, G1.v);
__m256i row12Temp2 = _mm256_unpacklo_epi16(_mm256_castph_si256(mulE), _mm256_castph_si256(mulG)); // E1 G1 E2 G2 E3 G3 E4 G4
__m256i row12Temp2Temp = row12Temp2;
__m256i row34Temp2 = _mm256_unpackhi_epi16(_mm256_castph_si256(mulE), _mm256_castph_si256(mulG)); // F1 H1 F2 H2 F3 H3 F4 H4
row12Temp1 = _mm256_unpacklo_epi16(row12Temp1, row12Temp2); // A1 E1 C1 G1 A2 E2 C2 G2
row12Temp2 = _mm256_unpacklo_epi16(row34Temp1, row34Temp2); // B1 F1 D1 H1 B2 F2 D2 H2
row34Temp1 = _mm256_unpackhi_epi16(row1TempTemp1, row34Temp1); // A3 E3 C3 G3 A4 E4 C4 G4
row34Temp2 = _mm256_unpackhi_epi16(row5TempTemp1, row34Temp2); // B3 F3 D3 H3 B4 F4 D4 H4
row34Temp2 = _mm256_unpackhi_epi16(row34Temp1, row34Temp2); // B3 F3 D3 H3 B4 F4 D4 H4
row34Temp1 = _mm256_unpackhi_epi16(row1TempTemp1, row12Temp2Temp); // A3 E3 C3 G3 A4 E4 C4 G4
__m256h row1 = _mm256_castsi256_ph(_mm256_unpacklo_epi16(row12Temp1, row12Temp2));// A1 B1 E1 F1 C1 D1 G1 H1
__m256h row2 = _mm256_castsi256_ph(_mm256_unpackhi_epi16(row12Temp1, row12Temp2));// A2 B2 E2 F2 C2 D2 G2 H2
@ -1335,20 +1472,22 @@ namespace Crafter {
} else {
__m512h mulA = _mm512_mul_ph(A0.v, A1.v);
__m512h mulC = _mm512_mul_ph(C0.v, C1.v);
__m512i row12Temp1 = _mm512_unpacklo_epi16(_mm512_castph_si512(mulA), _mm512_castph_si512(mulC)); // A1 C1 A2 C2 A3 C3 A4 C4
__m512i row34Temp1 = _mm512_unpackhi_epi16(_mm512_castph_si512(mulA), _mm512_castph_si512(mulC)); // B1 D1 B2 D2 B3 D3 B4 D4
__m512i row1TempTemp1 = row12Temp1;
__m512i row5TempTemp1 = row34Temp1;
__m512h mulE = _mm512_mul_ph(E0.v, E1.v);
__m512h mulG = _mm512_mul_ph(G0.v, G1.v);
__m512i row12Temp2 = _mm512_unpacklo_epi16(_mm512_castph_si512(mulE), _mm512_castph_si512(mulG)); // E1 G1 E2 G2 E3 G3 E4 G4
__m512i row12Temp2Temp = row12Temp2;
__m512i row34Temp2 = _mm512_unpackhi_epi16(_mm512_castph_si512(mulE), _mm512_castph_si512(mulG)); // F1 H1 F2 H2 F3 H3 F4 H4
row12Temp1 = _mm512_unpacklo_epi16(row12Temp1, row12Temp2); // A1 E1 C1 G1 A2 E2 C2 G2
row12Temp2 = _mm512_unpacklo_epi16(row34Temp1, row34Temp2); // B1 F1 D1 H1 B2 F2 D2 H2
row34Temp1 = _mm512_unpackhi_epi16(row1TempTemp1, row34Temp1); // A3 E3 C3 G3 A4 E4 C4 G4
row34Temp2 = _mm512_unpackhi_epi16(row5TempTemp1, row34Temp2); // B3 F3 D3 H3 B4 F4 D4 H4
row34Temp2 = _mm512_unpackhi_epi16(row34Temp1, row34Temp2); // B3 F3 D3 H3 B4 F4 D4 H4
row34Temp1 = _mm512_unpackhi_epi16(row1TempTemp1, row12Temp2Temp); // A3 E3 C3 G3 A4 E4 C4 G4
__m512h row1 = _mm512_castsi512_ph(_mm512_unpacklo_epi16(row12Temp1, row12Temp2));// A1 B1 E1 F1 C1 D1 G1 H1
__m512h row2 = _mm512_castsi512_ph(_mm512_unpackhi_epi16(row12Temp1, row12Temp2));// A2 B2 E2 F2 C2 D2 G2 H2
@ -1363,7 +1502,7 @@ namespace Crafter {
}
}
constexpr static VectorF16<1, Packing*2> Dot(
constexpr static VectorF16<1, Packing*2> DotNoShuffle(
VectorF16<Len, Packing> A0, VectorF16<Len, Packing> A1,
VectorF16<Len, Packing> E0, VectorF16<Len, Packing> E1
) requires(Len == 2 && Packing*Len == Alignment) {
@ -1397,10 +1536,7 @@ namespace Crafter {
__m256h row2 = _mm256_castsi256_ph(_mm256_unpackhi_epi16(row12Temp1, row12Temp2));// A2 B2 C2 D2 E2 F2 G2 H2
__m256h result = _mm256_add_ph(row1, row2);
VectorF16<16, 1> vec(result);
vec = vec.template Shuffle<{{0,1,2,3,8,9,10,11,4,5,6,7,12,13,14,15}}>();
return VectorF16<1, 16>(vec.v);
return result;
} else {
__m512h mulA = _mm512_mul_ph(A0.v, A1.v);
__m512h mulE = _mm512_mul_ph(E0.v, E1.v);
@ -1414,12 +1550,11 @@ namespace Crafter {
__m512h row1 = _mm512_castsi512_ph(_mm512_unpacklo_epi16(row12Temp1, row12Temp2));// A1 B1 C1 D1 E1 F1 G1 H1
__m512h row2 = _mm512_castsi512_ph(_mm512_unpackhi_epi16(row12Temp1, row12Temp2));// A2 B2 C2 D2 E2 F2 G2 H2
__m512h result = _mm512_add_ph(row1, row2);
VectorF16<32, 1> vec(result);
vec = vec.template Shuffle<{{0,1,2,3,8,9,10,11,16,17,18,19,24,25,26,27,4,5,6,7,12,13,14,15,20,21,22,23,28,29,30,31}}>();
return VectorF16<1, 32>(vec.v);
return result;
}
}
public:
template <std::array<bool, Len> ShuffleValues>
constexpr static VectorF16<Len, Packing> Blend(VectorF16<Len, Packing> a, VectorF16<Len, Packing> b) {

View file

@ -412,6 +412,51 @@ std::string* TestAllCombinations() {
}
}
// Len == 4 coverage: exercises the four-vector Length and Normalize overloads,
// which are only checked here when Packing*Len equals the type's Alignment.
if constexpr(Len == 4 && Packing*Len == VectorType<Len, Packing>::Alignment) {
{
// Length: vecC/vecE/vecG are vecA scaled by 2/3/4, so each reported length is
// expected to be the same multiple of expectedLength[0].
VectorType<Len, Packing> vecA(floats);
VectorType<Len, Packing> vecC = vecA * 2;
VectorType<Len, Packing> vecE = vecA * 3;
VectorType<Len, Packing> vecG = vecA * 4;
VectorType<1, Packing*4> result = VectorType<Len, Packing>::Length(vecA, vecC, vecE, vecG);
Vector<T, Packing*4, VectorType<Len, Packing>::Alignment> stored = result.Store();
// One entry per input is sampled at a stride of Packing:
// index 0 -> vecA, Packing -> vecC, Packing*2 -> vecE, Packing*3 -> vecG.
if (!FloatEquals(stored.v[0], expectedLength[0])) {
return new std::string(std::format("Length 4 vecA test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0], (float)stored.v[0]));
}
if (!FloatEquals(stored.v[Packing], expectedLength[0] * 2)) {
return new std::string(std::format("Length 4 vecC test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 2, (float)stored.v[Packing]));
}
if (!FloatEquals(stored.v[Packing*2], expectedLength[0] * 3)) {
return new std::string(std::format("Length 4 vecE test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 3, (float)stored.v[Packing*2]));
}
if (!FloatEquals(stored.v[Packing*3], expectedLength[0] * 4)) {
return new std::string(std::format("Length 4 vecG test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 4, (float)stored.v[Packing*3]));
}
}
{
// Normalize: normalize the same four inputs, then measure the length of the
// four outputs again; every stored length is expected to be 1.
VectorType<Len, Packing> vecA(floats);
VectorType<Len, Packing> vecC = vecA * 2;
VectorType<Len, Packing> vecE = vecA * 3;
VectorType<Len, Packing> vecG = vecA * 4;
// The Normalize result is unpacked with std::get<0..3>, one element per input.
auto result = VectorType<Len, Packing>::Normalize(vecA, vecC, vecE, vecG);
VectorType<1, Packing*4> result2 = VectorType<Len, Packing>::Length(std::get<0>(result), std::get<1>(result), std::get<2>(result), std::get<3>(result));
Vector<T, Packing*4, VectorType<Len, Packing>::Alignment> stored = result2.Store();
//std::println("{}", stored);
// Len*Packing equals Packing*4 here (Len == 4), so this walks every stored entry.
for(std::uint8_t i = 0; i < Len*Packing; i++) {
if (!FloatEquals(stored.v[i], T(1))) {
return new std::string(std::format("Normalize {} test failed at Len={} Packing={} Expected: {}, Got: {}", i, Len, Packing, 1, (float)stored.v[i]));
}
}
}
}
return TestAllCombinations<T, VectorType, MaxSize, Len, Packing + 1>();
}
}