added more tests
This commit is contained in:
parent
bd9230e07e
commit
d09155736f
2 changed files with 61 additions and 417 deletions
|
|
@ -745,148 +745,6 @@ namespace Crafter {
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Normalizes eight vectors (A..H) in one pass.
 *
 * Steps, as performed below:
 *   1. Length(A..H) yields one register holding the lengths of all inputs.
 *   2. The reciprocal of that register is taken once (1 / length).
 *   3. For each input, a byte shuffle replicates one 16-bit element of the
 *      reciprocal register across the register, and the input is multiplied
 *      by that replicated factor.
 *
 * No guard against a zero length is applied; the division is performed
 * unconditionally.
 *
 * NOTE(review): the masks select element 0 for A, 1 for B, ... 7 for H, i.e.
 * they assume Length() reports its results in A,B,C,D,E,F,G,H element order.
 * The 8-input Dot()/DotNoShuffle() path documents an A,E,C,G,B,F,D,H order in
 * its lane comments - confirm that the two agree.
 * NOTE(review): GetShuffleMaskEpi8 is defined elsewhere; it presumably expands
 * 16-bit element indices into a byte-shuffle control vector - confirm.
 *
 * @return Tuple of the scaled inputs, in the order A..H.
 */
constexpr static std::tuple<VectorF16<Len, Packing>, VectorF16<Len, Packing>, VectorF16<Len, Packing>, VectorF16<Len, Packing>, VectorF16<Len, Packing>, VectorF16<Len, Packing>, VectorF16<Len, Packing>, VectorF16<Len, Packing>> Normalize(
    VectorF16<Len, Packing> A,
    VectorF16<Len, Packing> B,
    VectorF16<Len, Packing> C,
    VectorF16<Len, Packing> D,
    VectorF16<Len, Packing> E,
    VectorF16<Len, Packing> F,
    VectorF16<Len, Packing> G,
    VectorF16<Len, Packing> H
) requires(Len == 8 && Packing*Len == Alignment) {
    constexpr std::array<std::uint8_t, Alignment*2> shuffleMaskA = GetShuffleMaskEpi8<{{0,0,0,0,0,0,0,0}}>();
    constexpr std::array<std::uint8_t, Alignment*2> shuffleMaskB = GetShuffleMaskEpi8<{{1,1,1,1,1,1,1,1}}>();
    constexpr std::array<std::uint8_t, Alignment*2> shuffleMaskC = GetShuffleMaskEpi8<{{2,2,2,2,2,2,2,2}}>();
    constexpr std::array<std::uint8_t, Alignment*2> shuffleMaskD = GetShuffleMaskEpi8<{{3,3,3,3,3,3,3,3}}>();
    constexpr std::array<std::uint8_t, Alignment*2> shuffleMaskE = GetShuffleMaskEpi8<{{4,4,4,4,4,4,4,4}}>();
    constexpr std::array<std::uint8_t, Alignment*2> shuffleMaskF = GetShuffleMaskEpi8<{{5,5,5,5,5,5,5,5}}>();
    constexpr std::array<std::uint8_t, Alignment*2> shuffleMaskG = GetShuffleMaskEpi8<{{6,6,6,6,6,6,6,6}}>();
    constexpr std::array<std::uint8_t, Alignment*2> shuffleMaskH = GetShuffleMaskEpi8<{{7,7,7,7,7,7,7,7}}>();

    // Deduce the type: Length() returns VectorF16<1, Packing*8>, which is not
    // the same specialization as the inputs. Only its raw register is needed.
    const auto lengths = Length(A, B, C, D, E, F, G, H);

    if constexpr(std::is_same_v<VectorType, __m128h>) {
        // Reciprocal of every length, kept as integer lanes for the byte shuffle.
        const __m128i invLengths = _mm_castph_si128(_mm_div_ph(_mm_set1_ph(static_cast<_Float16>(1)), lengths.v));
        // Replicates the element chosen by `mask` and scales `value` by it.
        const auto scale = [&](__m128h value, const std::array<std::uint8_t, Alignment*2>& mask) {
            const __m128i control = _mm_loadu_epi8(mask.data());
            return _mm_mul_ph(value, _mm_castsi128_ph(_mm_shuffle_epi8(invLengths, control)));
        };
        return {
            scale(A.v, shuffleMaskA),
            scale(B.v, shuffleMaskB),
            scale(C.v, shuffleMaskC),
            scale(D.v, shuffleMaskD),
            scale(E.v, shuffleMaskE),
            scale(F.v, shuffleMaskF),
            scale(G.v, shuffleMaskG),
            scale(H.v, shuffleMaskH)
        };
    } else if constexpr(std::is_same_v<VectorType, __m256h>) {
        const __m256i invLengths = _mm256_castph_si256(_mm256_div_ph(_mm256_set1_ph(static_cast<_Float16>(1)), lengths.v));
        const auto scale = [&](__m256h value, const std::array<std::uint8_t, Alignment*2>& mask) {
            const __m256i control = _mm256_loadu_epi8(mask.data());
            return _mm256_mul_ph(value, _mm256_castsi256_ph(_mm256_shuffle_epi8(invLengths, control)));
        };
        return {
            scale(A.v, shuffleMaskA),
            scale(B.v, shuffleMaskB),
            scale(C.v, shuffleMaskC),
            scale(D.v, shuffleMaskD),
            scale(E.v, shuffleMaskE),
            scale(F.v, shuffleMaskF),
            scale(G.v, shuffleMaskG),
            scale(H.v, shuffleMaskH)
        };
    } else {
        const __m512i invLengths = _mm512_castph_si512(_mm512_div_ph(_mm512_set1_ph(static_cast<_Float16>(1)), lengths.v));
        const auto scale = [&](__m512h value, const std::array<std::uint8_t, Alignment*2>& mask) {
            const __m512i control = _mm512_loadu_epi8(mask.data());
            return _mm512_mul_ph(value, _mm512_castsi512_ph(_mm512_shuffle_epi8(invLengths, control)));
        };
        return {
            scale(A.v, shuffleMaskA),
            scale(B.v, shuffleMaskB),
            scale(C.v, shuffleMaskC),
            scale(D.v, shuffleMaskD),
            scale(E.v, shuffleMaskE),
            scale(F.v, shuffleMaskF),
            scale(G.v, shuffleMaskG),
            scale(H.v, shuffleMaskH)
        };
    }
}
|
||||
|
||||
constexpr static std::tuple<VectorF16<Len, Packing>, VectorF16<Len, Packing>, VectorF16<Len, Packing>, VectorF16<Len, Packing>> Normalize(
|
||||
VectorF16<Len, Packing> A,
|
||||
VectorF16<Len, Packing> C,
|
||||
|
|
@ -995,26 +853,6 @@ namespace Crafter {
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Lengths of eight vectors at once.
 *
 * Takes the element-wise square root of whatever LengthSq(A..H) returns, so
 * the element order of the result is exactly the order LengthSq produces.
 *
 * @return One register holding the square roots of the LengthSq results.
 */
constexpr static VectorF16<1, Packing*8> Length(
    VectorF16<Len, Packing> A,
    VectorF16<Len, Packing> B,
    VectorF16<Len, Packing> C,
    VectorF16<Len, Packing> D,
    VectorF16<Len, Packing> E,
    VectorF16<Len, Packing> F,
    VectorF16<Len, Packing> G,
    VectorF16<Len, Packing> H
) requires(Len == 8 && Packing*Len == Alignment) {
    using Result = VectorF16<1, Packing*8>;

    const Result squared = LengthSq(A, B, C, D, E, F, G, H);

    // Pick the sqrt intrinsic that matches the register width in use.
    if constexpr(std::is_same_v<VectorType, __m128h>) {
        return Result(_mm_sqrt_ph(squared.v));
    } else if constexpr(std::is_same_v<VectorType, __m256h>) {
        return Result(_mm256_sqrt_ph(squared.v));
    } else {
        return Result(_mm512_sqrt_ph(squared.v));
    }
}
|
||||
|
||||
constexpr static VectorF16<1, Packing*4> Length(
|
||||
VectorF16<Len, Packing> A,
|
||||
VectorF16<Len, Packing> C,
|
||||
|
|
@ -1045,19 +883,6 @@ namespace Crafter {
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Squared lengths of eight vectors at once.
 *
 * Implemented as the dot product of every input with itself, delegated to
 * the 8-pair Dot() overload.
 *
 * @return Whatever Dot(A, A, ..., H, H) returns.
 */
constexpr static VectorF16<1, Packing*8> LengthSq(
    VectorF16<Len, Packing> A,
    VectorF16<Len, Packing> B,
    VectorF16<Len, Packing> C,
    VectorF16<Len, Packing> D,
    VectorF16<Len, Packing> E,
    VectorF16<Len, Packing> F,
    VectorF16<Len, Packing> G,
    VectorF16<Len, Packing> H
) requires(Len == 8 && Packing*Len == Alignment) {
    // Each vector is paired with itself: dot(v, v) == |v|^2.
    auto selfDots = Dot(
        A, A,
        B, B,
        C, C,
        D, D,
        E, E,
        F, F,
        G, G,
        H, H
    );
    return selfDots;
}
|
||||
|
||||
constexpr static VectorF16<1, Packing*4> LengthSq(
|
||||
VectorF16<Len, Packing> A,
|
||||
VectorF16<Len, Packing> C,
|
||||
|
|
@ -1074,29 +899,6 @@ namespace Crafter {
|
|||
return Dot(A, A, E, E);
|
||||
}
|
||||
|
||||
/**
 * Dot products of eight vector pairs (A0.A1, B0.B1, ..., H0.H1) at once.
 *
 * The reduction itself is done by DotNoShuffle(); this wrapper then applies
 * a width-specific element reorder for the 256-bit and 512-bit cases.
 *
 * Fix: all sixteen operands are now forwarded. The previous calls omitted
 * F0/F1, passing 14 arguments to the 16-parameter DotNoShuffle().
 *
 * NOTE(review): the 128-bit path returns DotNoShuffle()'s result untouched.
 * Per that function's lane comments its element order is A,E,C,G,B,F,D,H -
 * confirm whether callers expect A..H order here.
 * NOTE(review): the Shuffle<> index lists below move 4-element groups only;
 * confirm they yield the intended final order for eight inputs.
 *
 * @return One register holding the eight dot products (per packed vector).
 */
constexpr static VectorF16<1, Packing*8> Dot(
    VectorF16<Len, Packing> A0, VectorF16<Len, Packing> A1,
    VectorF16<Len, Packing> B0, VectorF16<Len, Packing> B1,
    VectorF16<Len, Packing> C0, VectorF16<Len, Packing> C1,
    VectorF16<Len, Packing> D0, VectorF16<Len, Packing> D1,
    VectorF16<Len, Packing> E0, VectorF16<Len, Packing> E1,
    VectorF16<Len, Packing> F0, VectorF16<Len, Packing> F1,
    VectorF16<Len, Packing> G0, VectorF16<Len, Packing> G1,
    VectorF16<Len, Packing> H0, VectorF16<Len, Packing> H1
) requires(Len == 8 && Packing*Len == Alignment) {
    // Single reduction shared by every width; F0/F1 are included.
    const auto unordered = DotNoShuffle(A0, A1, B0, B1, C0, C1, D0, D1, E0, E1, F0, F1, G0, G1, H0, H1);

    if constexpr(std::is_same_v<VectorType, __m128h>) {
        return unordered;
    } else if constexpr(std::is_same_v<VectorType, __m256h>) {
        // Reinterpret as 16 scalars so the generic Shuffle<> can regroup them.
        VectorF16<16, 1> vec(unordered.v);
        vec = vec.template Shuffle<{{0,1,2,3,8,9,10,11,4,5,6,7,12,13,14,15}}>();
        return vec.v;
    } else {
        // Reinterpret as 32 scalars so the generic Shuffle<> can regroup them.
        VectorF16<32, 1> vec(unordered.v);
        vec = vec.template Shuffle<{{0,1,2,3,8,9,10,11,16,17,18,19,24,25,26,27,4,5,6,7,12,13,14,15,20,21,22,23,28,29,30,31}}>();
        return vec.v;
    }
}
|
||||
|
||||
constexpr static VectorF16<1, Packing*4> Dot(
|
||||
VectorF16<Len, Packing> A0, VectorF16<Len, Packing> A1,
|
||||
VectorF16<Len, Packing> C0, VectorF16<Len, Packing> C1,
|
||||
|
|
@ -1152,26 +954,6 @@ namespace Crafter {
|
|||
|
||||
|
||||
private:
|
||||
/**
 * Lengths of eight vectors, without the final element reorder.
 *
 * Takes the element-wise square root of LengthSqNoShuffle(A..H); the result
 * keeps whatever element order that function produces.
 *
 * @return One register holding the square roots of the squared lengths.
 */
constexpr static VectorF16<1, Packing*8> LengthNoShuffle(
    VectorF16<Len, Packing> A,
    VectorF16<Len, Packing> B,
    VectorF16<Len, Packing> C,
    VectorF16<Len, Packing> D,
    VectorF16<Len, Packing> E,
    VectorF16<Len, Packing> F,
    VectorF16<Len, Packing> G,
    VectorF16<Len, Packing> H
) requires(Len == 8 && Packing*Len == Alignment) {
    using Result = VectorF16<1, Packing*8>;

    const Result squared = LengthSqNoShuffle(A, B, C, D, E, F, G, H);

    // Pick the sqrt intrinsic that matches the register width in use.
    if constexpr(std::is_same_v<VectorType, __m128h>) {
        return Result(_mm_sqrt_ph(squared.v));
    } else if constexpr(std::is_same_v<VectorType, __m256h>) {
        return Result(_mm256_sqrt_ph(squared.v));
    } else {
        return Result(_mm512_sqrt_ph(squared.v));
    }
}
|
||||
|
||||
constexpr static VectorF16<1, Packing*4> LengthNoShuffle(
|
||||
VectorF16<Len, Packing> A,
|
||||
VectorF16<Len, Packing> C,
|
||||
|
|
@ -1202,19 +984,6 @@ namespace Crafter {
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Squared lengths of eight vectors, without the final element reorder.
 *
 * Implemented as the dot product of every input with itself, delegated to
 * the 8-pair DotNoShuffle() overload.
 *
 * @return Whatever DotNoShuffle(A, A, ..., H, H) returns.
 */
constexpr static VectorF16<1, Packing*8> LengthSqNoShuffle(
    VectorF16<Len, Packing> A,
    VectorF16<Len, Packing> B,
    VectorF16<Len, Packing> C,
    VectorF16<Len, Packing> D,
    VectorF16<Len, Packing> E,
    VectorF16<Len, Packing> F,
    VectorF16<Len, Packing> G,
    VectorF16<Len, Packing> H
) requires(Len == 8 && Packing*Len == Alignment) {
    // Each vector is paired with itself: dot(v, v) == |v|^2.
    auto selfDots = DotNoShuffle(
        A, A,
        B, B,
        C, C,
        D, D,
        E, E,
        F, F,
        G, G,
        H, H
    );
    return selfDots;
}
|
||||
|
||||
constexpr static VectorF16<1, Packing*4> LengthSqNoShuffle(
|
||||
VectorF16<Len, Packing> A,
|
||||
VectorF16<Len, Packing> C,
|
||||
|
|
@ -1231,177 +1000,6 @@ namespace Crafter {
|
|||
return DotNoShuffle(A, A, E, E);
|
||||
}
|
||||
|
||||
/**
 * Dot products of eight vector pairs, without a final element reorder.
 *
 * Method (identical for the 128/256/512-bit branches; the unpack intrinsics
 * work independently inside every 128-bit lane):
 *   1. Multiply each pair element-wise.
 *   2. Run three rounds of 16-bit unpacklo/unpackhi so that each resulting
 *      register holds one element index of all eight products.
 *   3. Add the eight registers together.
 *
 * Because every round interleaves at 16-bit granularity, the eight sums end
 * up in the order A, E, C, G, B, F, D, H inside each 128-bit lane.
 *
 * Fix: the upper halves of the C/D and G/H products are now taken from
 * C/D and G/H. Previously they were unpacked from A/B and E/F a second time,
 * so elements 5..8 of C, D, G and H never reached the sum.
 *
 * The additions are performed in the same sequence as before so rounding is
 * unchanged for the terms that were already correct.
 *
 * @return One register with the eight dot products in A,E,C,G,B,F,D,H order
 *         per 128-bit lane.
 */
constexpr static VectorF16<1, Packing*8> DotNoShuffle(
    VectorF16<Len, Packing> A0, VectorF16<Len, Packing> A1,
    VectorF16<Len, Packing> B0, VectorF16<Len, Packing> B1,
    VectorF16<Len, Packing> C0, VectorF16<Len, Packing> C1,
    VectorF16<Len, Packing> D0, VectorF16<Len, Packing> D1,
    VectorF16<Len, Packing> E0, VectorF16<Len, Packing> E1,
    VectorF16<Len, Packing> F0, VectorF16<Len, Packing> F1,
    VectorF16<Len, Packing> G0, VectorF16<Len, Packing> G1,
    VectorF16<Len, Packing> H0, VectorF16<Len, Packing> H1
) requires(Len == 8 && Packing*Len == Alignment) {
    if constexpr(std::is_same_v<VectorType, __m128h>) {
        // Element-wise products, viewed as raw 16-bit lanes for the integer unpacks.
        const __m128i prodA = _mm_castph_si128(_mm_mul_ph(A0.v, A1.v));
        const __m128i prodB = _mm_castph_si128(_mm_mul_ph(B0.v, B1.v));
        const __m128i prodC = _mm_castph_si128(_mm_mul_ph(C0.v, C1.v));
        const __m128i prodD = _mm_castph_si128(_mm_mul_ph(D0.v, D1.v));
        const __m128i prodE = _mm_castph_si128(_mm_mul_ph(E0.v, E1.v));
        const __m128i prodF = _mm_castph_si128(_mm_mul_ph(F0.v, F1.v));
        const __m128i prodG = _mm_castph_si128(_mm_mul_ph(G0.v, G1.v));
        const __m128i prodH = _mm_castph_si128(_mm_mul_ph(H0.v, H1.v));

        // Round 1: interleave neighbouring inputs.
        const __m128i abLo = _mm_unpacklo_epi16(prodA, prodB); // A1 B1 A2 B2 A3 B3 A4 B4
        const __m128i abHi = _mm_unpackhi_epi16(prodA, prodB); // A5 B5 A6 B6 A7 B7 A8 B8
        const __m128i cdLo = _mm_unpacklo_epi16(prodC, prodD); // C1 D1 C2 D2 C3 D3 C4 D4
        const __m128i cdHi = _mm_unpackhi_epi16(prodC, prodD); // C5 D5 C6 D6 C7 D7 C8 D8
        const __m128i efLo = _mm_unpacklo_epi16(prodE, prodF); // E1 F1 E2 F2 E3 F3 E4 F4
        const __m128i efHi = _mm_unpackhi_epi16(prodE, prodF); // E5 F5 E6 F6 E7 F7 E8 F8
        const __m128i ghLo = _mm_unpacklo_epi16(prodG, prodH); // G1 H1 G2 H2 G3 H3 G4 H4
        const __m128i ghHi = _mm_unpackhi_epi16(prodG, prodH); // G5 H5 G6 H6 G7 H7 G8 H8

        // Round 2: merge the pairs into groups of four inputs.
        const __m128i abcd12 = _mm_unpacklo_epi16(abLo, cdLo); // A1 C1 B1 D1 A2 C2 B2 D2
        const __m128i abcd34 = _mm_unpackhi_epi16(abLo, cdLo); // A3 C3 B3 D3 A4 C4 B4 D4
        const __m128i abcd56 = _mm_unpacklo_epi16(abHi, cdHi); // A5 C5 B5 D5 A6 C6 B6 D6
        const __m128i abcd78 = _mm_unpackhi_epi16(abHi, cdHi); // A7 C7 B7 D7 A8 C8 B8 D8
        const __m128i efgh12 = _mm_unpacklo_epi16(efLo, ghLo); // E1 G1 F1 H1 E2 G2 F2 H2
        const __m128i efgh34 = _mm_unpackhi_epi16(efLo, ghLo); // E3 G3 F3 H3 E4 G4 F4 H4
        const __m128i efgh56 = _mm_unpacklo_epi16(efHi, ghHi); // E5 G5 F5 H5 E6 G6 F6 H6
        const __m128i efgh78 = _mm_unpackhi_epi16(efHi, ghHi); // E7 G7 F7 H7 E8 G8 F8 H8

        // Round 3 + reduction: each unpack yields one element index of all
        // eight inputs (order A E C G B F D H); accumulate them.
        __m128h sum = _mm_castsi128_ph(_mm_unpackhi_epi16(abcd12, efgh12));          // element 2
        sum = _mm_add_ph(sum, _mm_castsi128_ph(_mm_unpacklo_epi16(abcd12, efgh12))); // element 1
        sum = _mm_add_ph(sum, _mm_castsi128_ph(_mm_unpackhi_epi16(abcd34, efgh34))); // element 4
        sum = _mm_add_ph(sum, _mm_castsi128_ph(_mm_unpacklo_epi16(abcd34, efgh34))); // element 3
        sum = _mm_add_ph(sum, _mm_castsi128_ph(_mm_unpackhi_epi16(abcd56, efgh56))); // element 6
        sum = _mm_add_ph(sum, _mm_castsi128_ph(_mm_unpacklo_epi16(abcd56, efgh56))); // element 5
        sum = _mm_add_ph(sum, _mm_castsi128_ph(_mm_unpackhi_epi16(abcd78, efgh78))); // element 8
        sum = _mm_add_ph(sum, _mm_castsi128_ph(_mm_unpacklo_epi16(abcd78, efgh78))); // element 7

        return sum;
    } else if constexpr(std::is_same_v<VectorType, __m256h>) {
        // Same scheme as the 128-bit branch, applied to each 128-bit lane.
        const __m256i prodA = _mm256_castph_si256(_mm256_mul_ph(A0.v, A1.v));
        const __m256i prodB = _mm256_castph_si256(_mm256_mul_ph(B0.v, B1.v));
        const __m256i prodC = _mm256_castph_si256(_mm256_mul_ph(C0.v, C1.v));
        const __m256i prodD = _mm256_castph_si256(_mm256_mul_ph(D0.v, D1.v));
        const __m256i prodE = _mm256_castph_si256(_mm256_mul_ph(E0.v, E1.v));
        const __m256i prodF = _mm256_castph_si256(_mm256_mul_ph(F0.v, F1.v));
        const __m256i prodG = _mm256_castph_si256(_mm256_mul_ph(G0.v, G1.v));
        const __m256i prodH = _mm256_castph_si256(_mm256_mul_ph(H0.v, H1.v));

        // Round 1: interleave neighbouring inputs (per 128-bit lane).
        const __m256i abLo = _mm256_unpacklo_epi16(prodA, prodB); // A1 B1 A2 B2 A3 B3 A4 B4
        const __m256i abHi = _mm256_unpackhi_epi16(prodA, prodB); // A5 B5 A6 B6 A7 B7 A8 B8
        const __m256i cdLo = _mm256_unpacklo_epi16(prodC, prodD); // C1 D1 C2 D2 C3 D3 C4 D4
        const __m256i cdHi = _mm256_unpackhi_epi16(prodC, prodD); // C5 D5 C6 D6 C7 D7 C8 D8
        const __m256i efLo = _mm256_unpacklo_epi16(prodE, prodF); // E1 F1 E2 F2 E3 F3 E4 F4
        const __m256i efHi = _mm256_unpackhi_epi16(prodE, prodF); // E5 F5 E6 F6 E7 F7 E8 F8
        const __m256i ghLo = _mm256_unpacklo_epi16(prodG, prodH); // G1 H1 G2 H2 G3 H3 G4 H4
        const __m256i ghHi = _mm256_unpackhi_epi16(prodG, prodH); // G5 H5 G6 H6 G7 H7 G8 H8

        // Round 2: merge the pairs into groups of four inputs.
        const __m256i abcd12 = _mm256_unpacklo_epi16(abLo, cdLo); // A1 C1 B1 D1 A2 C2 B2 D2
        const __m256i abcd34 = _mm256_unpackhi_epi16(abLo, cdLo); // A3 C3 B3 D3 A4 C4 B4 D4
        const __m256i abcd56 = _mm256_unpacklo_epi16(abHi, cdHi); // A5 C5 B5 D5 A6 C6 B6 D6
        const __m256i abcd78 = _mm256_unpackhi_epi16(abHi, cdHi); // A7 C7 B7 D7 A8 C8 B8 D8
        const __m256i efgh12 = _mm256_unpacklo_epi16(efLo, ghLo); // E1 G1 F1 H1 E2 G2 F2 H2
        const __m256i efgh34 = _mm256_unpackhi_epi16(efLo, ghLo); // E3 G3 F3 H3 E4 G4 F4 H4
        const __m256i efgh56 = _mm256_unpacklo_epi16(efHi, ghHi); // E5 G5 F5 H5 E6 G6 F6 H6
        const __m256i efgh78 = _mm256_unpackhi_epi16(efHi, ghHi); // E7 G7 F7 H7 E8 G8 F8 H8

        // Round 3 + reduction (order A E C G B F D H per lane).
        __m256h sum = _mm256_castsi256_ph(_mm256_unpackhi_epi16(abcd12, efgh12));             // element 2
        sum = _mm256_add_ph(sum, _mm256_castsi256_ph(_mm256_unpacklo_epi16(abcd12, efgh12))); // element 1
        sum = _mm256_add_ph(sum, _mm256_castsi256_ph(_mm256_unpackhi_epi16(abcd34, efgh34))); // element 4
        sum = _mm256_add_ph(sum, _mm256_castsi256_ph(_mm256_unpacklo_epi16(abcd34, efgh34))); // element 3
        sum = _mm256_add_ph(sum, _mm256_castsi256_ph(_mm256_unpackhi_epi16(abcd56, efgh56))); // element 6
        sum = _mm256_add_ph(sum, _mm256_castsi256_ph(_mm256_unpacklo_epi16(abcd56, efgh56))); // element 5
        sum = _mm256_add_ph(sum, _mm256_castsi256_ph(_mm256_unpackhi_epi16(abcd78, efgh78))); // element 8
        sum = _mm256_add_ph(sum, _mm256_castsi256_ph(_mm256_unpacklo_epi16(abcd78, efgh78))); // element 7

        return sum;
    } else {
        // Same scheme as the 128-bit branch, applied to each 128-bit lane.
        const __m512i prodA = _mm512_castph_si512(_mm512_mul_ph(A0.v, A1.v));
        const __m512i prodB = _mm512_castph_si512(_mm512_mul_ph(B0.v, B1.v));
        const __m512i prodC = _mm512_castph_si512(_mm512_mul_ph(C0.v, C1.v));
        const __m512i prodD = _mm512_castph_si512(_mm512_mul_ph(D0.v, D1.v));
        const __m512i prodE = _mm512_castph_si512(_mm512_mul_ph(E0.v, E1.v));
        const __m512i prodF = _mm512_castph_si512(_mm512_mul_ph(F0.v, F1.v));
        const __m512i prodG = _mm512_castph_si512(_mm512_mul_ph(G0.v, G1.v));
        const __m512i prodH = _mm512_castph_si512(_mm512_mul_ph(H0.v, H1.v));

        // Round 1: interleave neighbouring inputs (per 128-bit lane).
        const __m512i abLo = _mm512_unpacklo_epi16(prodA, prodB); // A1 B1 A2 B2 A3 B3 A4 B4
        const __m512i abHi = _mm512_unpackhi_epi16(prodA, prodB); // A5 B5 A6 B6 A7 B7 A8 B8
        const __m512i cdLo = _mm512_unpacklo_epi16(prodC, prodD); // C1 D1 C2 D2 C3 D3 C4 D4
        const __m512i cdHi = _mm512_unpackhi_epi16(prodC, prodD); // C5 D5 C6 D6 C7 D7 C8 D8
        const __m512i efLo = _mm512_unpacklo_epi16(prodE, prodF); // E1 F1 E2 F2 E3 F3 E4 F4
        const __m512i efHi = _mm512_unpackhi_epi16(prodE, prodF); // E5 F5 E6 F6 E7 F7 E8 F8
        const __m512i ghLo = _mm512_unpacklo_epi16(prodG, prodH); // G1 H1 G2 H2 G3 H3 G4 H4
        const __m512i ghHi = _mm512_unpackhi_epi16(prodG, prodH); // G5 H5 G6 H6 G7 H7 G8 H8

        // Round 2: merge the pairs into groups of four inputs.
        const __m512i abcd12 = _mm512_unpacklo_epi16(abLo, cdLo); // A1 C1 B1 D1 A2 C2 B2 D2
        const __m512i abcd34 = _mm512_unpackhi_epi16(abLo, cdLo); // A3 C3 B3 D3 A4 C4 B4 D4
        const __m512i abcd56 = _mm512_unpacklo_epi16(abHi, cdHi); // A5 C5 B5 D5 A6 C6 B6 D6
        const __m512i abcd78 = _mm512_unpackhi_epi16(abHi, cdHi); // A7 C7 B7 D7 A8 C8 B8 D8
        const __m512i efgh12 = _mm512_unpacklo_epi16(efLo, ghLo); // E1 G1 F1 H1 E2 G2 F2 H2
        const __m512i efgh34 = _mm512_unpackhi_epi16(efLo, ghLo); // E3 G3 F3 H3 E4 G4 F4 H4
        const __m512i efgh56 = _mm512_unpacklo_epi16(efHi, ghHi); // E5 G5 F5 H5 E6 G6 F6 H6
        const __m512i efgh78 = _mm512_unpackhi_epi16(efHi, ghHi); // E7 G7 F7 H7 E8 G8 F8 H8

        // Round 3 + reduction (order A E C G B F D H per lane).
        __m512h sum = _mm512_castsi512_ph(_mm512_unpackhi_epi16(abcd12, efgh12));             // element 2
        sum = _mm512_add_ph(sum, _mm512_castsi512_ph(_mm512_unpacklo_epi16(abcd12, efgh12))); // element 1
        sum = _mm512_add_ph(sum, _mm512_castsi512_ph(_mm512_unpackhi_epi16(abcd34, efgh34))); // element 4
        sum = _mm512_add_ph(sum, _mm512_castsi512_ph(_mm512_unpacklo_epi16(abcd34, efgh34))); // element 3
        sum = _mm512_add_ph(sum, _mm512_castsi512_ph(_mm512_unpackhi_epi16(abcd56, efgh56))); // element 6
        sum = _mm512_add_ph(sum, _mm512_castsi512_ph(_mm512_unpacklo_epi16(abcd56, efgh56))); // element 5
        sum = _mm512_add_ph(sum, _mm512_castsi512_ph(_mm512_unpackhi_epi16(abcd78, efgh78))); // element 8
        sum = _mm512_add_ph(sum, _mm512_castsi512_ph(_mm512_unpacklo_epi16(abcd78, efgh78))); // element 7

        return sum;
    }
}
|
||||
|
||||
constexpr static VectorF16<1, Packing*4> DotNoShuffle(
|
||||
VectorF16<Len, Packing> A0, VectorF16<Len, Packing> A1,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue