diff --git a/interfaces/Crafter.Math-VectorF16.cppm b/interfaces/Crafter.Math-VectorF16.cppm index 81b8557..3e8ff0f 100755 --- a/interfaces/Crafter.Math-VectorF16.cppm +++ b/interfaces/Crafter.Math-VectorF16.cppm @@ -893,82 +893,58 @@ namespace Crafter { VectorF16 E, VectorF16 G ) requires(Len == 4 && Packing*Len == Alignment) { - constexpr std::array shuffleMaskA = GetShuffleMaskEpi8<{{0,0,0,0}}>(); - constexpr std::array shuffleMaskC = GetShuffleMaskEpi8<{{1,1,1,1}}>(); - constexpr std::array shuffleMaskE = GetShuffleMaskEpi8<{{2,2,2,2}}>(); - constexpr std::array shuffleMaskG = GetShuffleMaskEpi8<{{3,3,3,3}}>(); - if constexpr(std::is_same_v) { - VectorF16 lenght = Length(A, C, E, G); + VectorF16<1, 8> lenght = LengthNoShuffle(A, E, C, G); constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1}; __m128h one = _mm_loadu_ph(oneArr); - __m128h fLenght = _mm_div_ph(one, lenght.v); + VectorF16<8, 1> fLenght(_mm_div_ph(one, lenght.v)); - __m128i shuffleVecA = _mm_loadu_epi8(shuffleMaskA.data()); - __m128h fLenghtA = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecA)); - - __m128i shuffleVecC = _mm_loadu_epi8(shuffleMaskC.data()); - __m128h fLenghtC = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecC)); - - __m128i shuffleVecE = _mm_loadu_epi8(shuffleMaskE.data()); - __m128h fLenghtE = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecE)); - - __m128i shuffleVecG = _mm_loadu_epi8(shuffleMaskG.data()); - __m128h fLenghtG = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecG)); + VectorF16<8, 1> fLenghtA = fLenght.template Shuffle<{{0,0,0,0,1,1,1,1}}>(); + VectorF16<8, 1> fLenghtC = fLenght.template Shuffle<{{2,2,2,2,3,3,3,3}}>(); + VectorF16<8, 1> fLenghtE = fLenght.template Shuffle<{{4,4,4,4,5,5,5,5}}>(); + VectorF16<8, 1> fLenghtG = fLenght.template Shuffle<{{6,6,6,6,7,7,7,7}}>(); return { - _mm_mul_ph(A.v, fLenghtA), - _mm_mul_ph(C.v, fLenghtC), - _mm_mul_ph(E.v, fLenghtE), - 
_mm_mul_ph(G.v, fLenghtG), + _mm_mul_ph(A.v, fLenghtA.v), + _mm_mul_ph(C.v, fLenghtC.v), + _mm_mul_ph(E.v, fLenghtE.v), + _mm_mul_ph(G.v, fLenghtG.v) }; } else if constexpr(std::is_same_v) { - VectorF16 lenght = Length(A, C, E, G); + VectorF16<1, 16> lenght = LengthNoShuffle(A, E, C, G); constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; __m256h one = _mm256_loadu_ph(oneArr); - __m256h fLenght = _mm256_div_ph(one, lenght.v); + VectorF16<16, 1> fLenght(_mm256_div_ph(one, lenght.v)); + - __m256i shuffleVecA = _mm256_loadu_epi8(shuffleMaskA.data()); - __m256h fLenghtA = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecA)); - - __m256i shuffleVecC = _mm256_loadu_epi8(shuffleMaskC.data()); - __m256h fLenghtC = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecC)); - - __m256i shuffleVecE = _mm256_loadu_epi8(shuffleMaskE.data()); - __m256h fLenghtE = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecE)); - - __m256i shuffleVecG = _mm256_loadu_epi8(shuffleMaskG.data()); - __m256h fLenghtG = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecG)); + VectorF16<16, 1> fLenghtA = fLenght.template Shuffle<{{0,0,0,0,1,1,1,1,8,8,8,8,9,9,9,9}}>(); + VectorF16<16, 1> fLenghtC = fLenght.template Shuffle<{{2,2,2,2,3,3,3,3,10,10,10,10,11,11,11,11}}>(); + VectorF16<16, 1> fLenghtE = fLenght.template Shuffle<{{4,4,4,4,5,5,5,5,12,12,12,12,13,13,13,13}}>(); + VectorF16<16, 1> fLenghtG = fLenght.template Shuffle<{{6,6,6,6,7,7,7,7,14,14,14,14,15,15,15,15}}>(); return { - _mm256_mul_ph(A.v, fLenghtA), - _mm256_mul_ph(C.v, fLenghtC), - _mm256_mul_ph(E.v, fLenghtE), - _mm256_mul_ph(G.v, fLenghtG), + _mm256_mul_ph(A.v, fLenghtA.v), + _mm256_mul_ph(C.v, fLenghtC.v), + _mm256_mul_ph(E.v, fLenghtE.v), + _mm256_mul_ph(G.v, fLenghtG.v) }; } else { - VectorF16 lenght = Length(A, C, E, G); + VectorF16<1, 32> lenght = LengthNoShuffle(A, E, C, G); constexpr 
_Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; __m512h one = _mm512_loadu_ph(oneArr); - __m512h fLenght = _mm512_div_ph(one, lenght.v); + VectorF16<32, 1> fLenght(_mm512_div_ph(one, lenght.v)); + /* fLenght holds 1/length for every input vector; each Shuffle below repeats one reciprocal across four consecutive elements so it can scale all four components of its vector. */ - __m512i shuffleVecA = _mm512_loadu_epi8(shuffleMaskA.data()); - __m512h fLenghtA = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecA)); - - __m512i shuffleVecC = _mm512_loadu_epi8(shuffleMaskC.data()); - __m512h fLenghtC = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecC)); - - __m512i shuffleVecE = _mm512_loadu_epi8(shuffleMaskE.data()); - __m512h fLenghtE = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecE)); - - __m512i shuffleVecG = _mm512_loadu_epi8(shuffleMaskG.data()); - __m512h fLenghtG = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecG)); + VectorF16<32, 1> fLenghtA = fLenght.template Shuffle<{{0,0,0,0,1,1,1,1,8,8,8,8,9,9,9,9,16,16,16,16,17,17,17,17,24,24,24,24,25,25,25,25}}>(); + VectorF16<32, 1> fLenghtC = fLenght.template Shuffle<{{2,2,2,2,3,3,3,3,10,10,10,10,11,11,11,11,18,18,18,18,19,19,19,19,26,26,26,26,27,27,27,27}}>(); + VectorF16<32, 1> fLenghtE = fLenght.template Shuffle<{{4,4,4,4,5,5,5,5,12,12,12,12,13,13,13,13,20,20,20,20,21,21,21,21,28,28,28,28,29,29,29,29}}>(); + VectorF16<32, 1> fLenghtG = fLenght.template Shuffle<{{6,6,6,6,7,7,7,7,14,14,14,14,15,15,15,15,22,22,22,22,23,23,23,23,30,30,30,30,31,31,31,31}}>(); return { - VectorF16(_mm512_mul_ph(A.v, fLenghtA)), - VectorF16(_mm512_mul_ph(C.v, fLenghtC)), - VectorF16(_mm512_mul_ph(E.v, fLenghtE)), - VectorF16(_mm512_mul_ph(G.v, fLenghtG)), + VectorF16(_mm512_mul_ph(A.v, fLenghtA.v)), + VectorF16(_mm512_mul_ph(C.v, fLenghtC.v)), + VectorF16(_mm512_mul_ph(E.v, fLenghtE.v)), + VectorF16(_mm512_mul_ph(G.v, fLenghtG.v)), }; } } @@ -978,7 +954,7 @@ namespace Crafter {
VectorF16 E ) requires(Len == 2 && Packing*Len == Alignment) { if constexpr(std::is_same_v) { - VectorF16<1, 8> lenght = Length(A, E); + VectorF16<1, 8> lenght = LengthNoShuffle(A, E); constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1}; __m128h one = _mm_loadu_ph(oneArr); VectorF16<8, 1> fLenght(_mm_div_ph(one, lenght.v)); @@ -991,26 +967,26 @@ namespace Crafter { _mm_mul_ph(E.v, fLenghtE.v), }; } else if constexpr(std::is_same_v) { - VectorF16<1, 16> lenght = Length(A, E); + VectorF16<1, 16> lenght = LengthNoShuffle(A, E); constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; __m256h one = _mm256_loadu_ph(oneArr); VectorF16<16, 1> fLenght(_mm256_div_ph(one, lenght.v)); - VectorF16<16, 1> fLenghtA = fLenght.template Shuffle<{{0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7}}>(); - VectorF16<16, 1> fLenghtE = fLenght.template Shuffle<{{8,8,9,9,10,10,11,11,12,12,13,13,14,14,15,15}}>(); + VectorF16<16, 1> fLenghtA = fLenght.template Shuffle<{{0,0,1,1,2,2,3,3,8,8,9,9,10,10,11,11}}>(); + VectorF16<16, 1> fLenghtE = fLenght.template Shuffle<{{4,4,5,5,6,6,7,7,12,12,13,13,14,14,15,15}}>(); return { _mm256_mul_ph(A.v, fLenghtA.v), _mm256_mul_ph(E.v, fLenghtE.v), }; } else { - VectorF16<1, 32> lenght = Length(A, E); + VectorF16<1, 32> lenght = LengthNoShuffle(A, E); constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; __m512h one = _mm512_loadu_ph(oneArr); VectorF16<32, 1> fLenght(_mm512_div_ph(one, lenght.v)); - VectorF16<32, 1> fLenghtA = fLenght.template Shuffle<{{0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13,14,14,15,15}}>(); - VectorF16<32, 1> fLenghtE = fLenght.template Shuffle<{{16,16,17,17,18,18,19,19,20,20,21,21,22,22,23,23,24,24,25,25,26,26,27,27,28,28,29,29,30,30,31,31}}>(); + VectorF16<32, 1> fLenghtA = fLenght.template Shuffle<{{0,0,1,1,2,2,3,3,8,8,9,9,10,10,11,11,16,16,17,17,18,18,19,19,24,24,25,25,26,26,27,27}}>(); + VectorF16<32, 1> fLenghtE = fLenght.template 
Shuffle<{{4,4,5,5,6,6,7,7,12,12,13,13,14,14,15,15,20,20,21,21,22,22,23,23,28,28,29,29,30,30,31,31}}>(); return { _mm512_mul_ph(A.v, fLenghtA.v), @@ -1108,6 +1084,163 @@ namespace Crafter { VectorF16 G0, VectorF16 G1, VectorF16 H0, VectorF16 H1 ) requires(Len == 8 && Packing*Len == Alignment) { + /* Forwards all eight operand pairs (A..H) to DotNoShuffle; the two wider-register branches then permute the result with Shuffle. NOTE(review): assumes F0/F1 are declared in the part of the parameter list above this hunk - confirm. */ if constexpr(std::is_same_v) { + return DotNoShuffle(A0, A1, B0, B1, C0, C1, D0, D1, E0, E1, F0, F1, G0, G1, H0, H1); + } else if constexpr(std::is_same_v) { + VectorF16<16, 1> vec(DotNoShuffle(A0, A1, B0, B1, C0, C1, D0, D1, E0, E1, F0, F1, G0, G1, H0, H1).v); + vec = vec.template Shuffle<{{0,1,2,3,8,9,10,11,4,5,6,7,12,13,14,15}}>(); + return vec.v; + } else { + VectorF16<32, 1> vec(DotNoShuffle(A0, A1, B0, B1, C0, C1, D0, D1, E0, E1, F0, F1, G0, G1, H0, H1).v); + vec = vec.template Shuffle<{{0,1,2,3,8,9,10,11,16,17,18,19,24,25,26,27,4,5,6,7,12,13,14,15,20,21,22,23,28,29,30,31}}>(); + return vec.v; + } + } + + /* Dot for four operand pairs: wraps DotNoShuffle and, in the two wider-register branches, permutes its output with Shuffle. The first branch instead passes the C and E pairs in swapped positions and returns the result as-is. */ constexpr static VectorF16<1, Packing*4> Dot( + VectorF16 A0, VectorF16 A1, + VectorF16 C0, VectorF16 C1, + VectorF16 E0, VectorF16 E1, + VectorF16 G0, VectorF16 G1 + ) requires(Len == 4 && Packing*Len == Alignment) { + if constexpr(std::is_same_v) { + return DotNoShuffle(A0, A1, E0, E1, C0, C1, G0, G1); + } else if constexpr(std::is_same_v) { + VectorF16<16, 1> vec(DotNoShuffle(A0, A1, C0, C1, E0, E1, G0, G1).v); + vec = vec.template Shuffle<{{ + 0,1,8,9, + 4,5,12,13, + 2,3,10,11, + 6,7,14,15 + }}>(); + return vec.v; + } else { + VectorF16<32, 1> vec(DotNoShuffle(A0, A1, C0, C1, E0, E1, G0, G1).v); + vec = vec.template Shuffle<{{ + 0,1,8,9, + 16,17,24,25, + + 4,5,12,13, + 20,21,28,29, + + 2,3,10,11, + 18,19,26,27, /* each group gathers one element pair from all four 8-element lanes: base, base+8, base+16, base+24 (here base = 2,3) */ + + 6,7,14,15, + 22,23,30,31 + }}>(); + return vec.v; + } + } + + constexpr static VectorF16<1, Packing*2> Dot( + VectorF16 A0, VectorF16 A1, + VectorF16 E0, VectorF16 E1 + ) requires(Len == 2 && Packing*Len == Alignment) { + if constexpr(std::is_same_v) { + return DotNoShuffle(A0, A1, E0, E1); + } else if constexpr(std::is_same_v) { + VectorF16<16, 1> vec(DotNoShuffle(A0, A1,
E0, E1).v); + vec = vec.template Shuffle<{{0,1,2,3,8,9,10,11,4,5,6,7,12,13,14,15}}>(); + return vec.v; + } else { + VectorF16<32, 1> vec(DotNoShuffle(A0, A1, E0, E1).v); + vec = vec.template Shuffle<{{0,1,2,3,8,9,10,11,16,17,18,19,24,25,26,27,4,5,6,7,12,13,14,15,20,21,22,23,28,29,30,31}}>(); + return vec.v; + } + } + + + private: + constexpr static VectorF16<1, Packing*8> LengthNoShuffle( + VectorF16 A, + VectorF16 B, + VectorF16 C, + VectorF16 D, + VectorF16 E, + VectorF16 F, + VectorF16 G, + VectorF16 H + ) requires(Len == 8 && Packing*Len == Alignment) { + VectorF16<1, Packing*8> lenghtSq = LengthSqNoShuffle(A, B, C, D, E, F, G, H); + if constexpr(std::is_same_v) { + return VectorF16<1, Packing*8>(_mm_sqrt_ph(lenghtSq.v)); + } else if constexpr(std::is_same_v) { + return VectorF16<1, Packing*8>(_mm256_sqrt_ph(lenghtSq.v)); + } else { + return VectorF16<1, Packing*8>(_mm512_sqrt_ph(lenghtSq.v)); + } + } + + constexpr static VectorF16<1, Packing*4> LengthNoShuffle( + VectorF16 A, + VectorF16 C, + VectorF16 E, + VectorF16 G + ) requires(Len == 4 && Packing*Len == Alignment) { + VectorF16<1, Packing*4> lenghtSq = LengthSqNoShuffle(A, C, E, G); + if constexpr(std::is_same_v) { + return VectorF16<1, Packing*4>(_mm_sqrt_ph(lenghtSq.v)); + } else if constexpr(std::is_same_v) { + return VectorF16<1, Packing*4>(_mm256_sqrt_ph(lenghtSq.v)); + } else { + return VectorF16<1, Packing*4>(_mm512_sqrt_ph(lenghtSq.v)); + } + } + + constexpr static VectorF16<1, Packing*2> LengthNoShuffle( + VectorF16 A, + VectorF16 E + ) requires(Len == 2 && Packing*Len == Alignment) { + VectorF16<1, Packing*2> lenghtSq = LengthSqNoShuffle(A, E); + if constexpr(std::is_same_v) { + return VectorF16<1, Packing*2>(_mm_sqrt_ph(lenghtSq.v)); + } else if constexpr(std::is_same_v) { + return VectorF16<1, Packing*2>(_mm256_sqrt_ph(lenghtSq.v)); + } else { + return VectorF16<1, Packing*2>(_mm512_sqrt_ph(lenghtSq.v)); + } + } + + constexpr static VectorF16<1, Packing*8> LengthSqNoShuffle( + VectorF16 A, + 
VectorF16 B, + VectorF16 C, + VectorF16 D, + VectorF16 E, + VectorF16 F, + VectorF16 G, + VectorF16 H + ) requires(Len == 8 && Packing*Len == Alignment) { + return DotNoShuffle(A, A, B, B, C, C, D, D, E, E, F, F, G, G, H, H); + } + + constexpr static VectorF16<1, Packing*4> LengthSqNoShuffle( + VectorF16 A, + VectorF16 C, + VectorF16 E, + VectorF16 G + ) requires(Len == 4 && Packing*Len == Alignment) { + return DotNoShuffle(A, A, C, C, E, E, G, G); + } + + constexpr static VectorF16<1, Packing*2> LengthSqNoShuffle( + VectorF16 A, + VectorF16 E + ) requires(Len == 2 && Packing*Len == Alignment) { + return DotNoShuffle(A, A, E, E); + } + + constexpr static VectorF16<1, Packing*8> DotNoShuffle( + VectorF16 A0, VectorF16 A1, + VectorF16 B0, VectorF16 B1, + VectorF16 C0, VectorF16 C1, + VectorF16 D0, VectorF16 D1, + VectorF16 E0, VectorF16 E1, + VectorF16 F0, VectorF16 F1, + VectorF16 G0, VectorF16 G1, + VectorF16 H0, VectorF16 H1 + ) requires(Len == 8 && Packing*Len == Alignment) { if constexpr(std::is_same_v) { __m128h mulA = _mm_mul_ph(A0.v, A1.v); __m128h mulB = _mm_mul_ph(B0.v, B1.v); @@ -1270,7 +1403,7 @@ namespace Crafter { } } - constexpr static VectorF16<1, Packing*4> Dot( + constexpr static VectorF16<1, Packing*4> DotNoShuffle( VectorF16 A0, VectorF16 A1, VectorF16 C0, VectorF16 C1, VectorF16 E0, VectorF16 E1, @@ -1279,20 +1412,22 @@ namespace Crafter { if constexpr(std::is_same_v) { __m128h mulA = _mm_mul_ph(A0.v, A1.v); __m128h mulC = _mm_mul_ph(C0.v, C1.v); + __m128i row12Temp1 = _mm_unpacklo_epi16(_mm_castph_si128(mulA), _mm_castph_si128(mulC)); // A1 C1 A2 C2 A3 C3 A4 C4 __m128i row34Temp1 = _mm_unpackhi_epi16(_mm_castph_si128(mulA), _mm_castph_si128(mulC)); // B1 D1 B2 D2 B3 D3 B4 D4 __m128i row1TempTemp1 = row12Temp1; - __m128i row5TempTemp1 = row34Temp1; __m128h mulE = _mm_mul_ph(E0.v, E1.v); __m128h mulG = _mm_mul_ph(G0.v, G1.v); + __m128i row12Temp2 = _mm_unpacklo_epi16(_mm_castph_si128(mulE), _mm_castph_si128(mulG)); // E1 G1 E2 G2 E3 G3 E4 G4 + 
__m128i row12Temp2Temp = row12Temp2; __m128i row34Temp2 = _mm_unpackhi_epi16(_mm_castph_si128(mulE), _mm_castph_si128(mulG)); // F1 H1 F2 H2 F3 H3 F4 H4 row12Temp1 = _mm_unpacklo_epi16(row12Temp1, row12Temp2); // A1 E1 C1 G1 A2 E2 C2 G2 row12Temp2 = _mm_unpacklo_epi16(row34Temp1, row34Temp2); // B1 F1 D1 H1 B2 F2 D2 H2 - row34Temp1 = _mm_unpackhi_epi16(row1TempTemp1, row34Temp1); // A3 E3 C3 G3 A4 E4 C4 G4 - row34Temp2 = _mm_unpackhi_epi16(row5TempTemp1, row34Temp2); // B3 F3 D3 H3 B4 F4 D4 H4 + row34Temp2 = _mm_unpackhi_epi16(row34Temp1, row34Temp2); // B3 F3 D3 H3 B4 F4 D4 H4 + row34Temp1 = _mm_unpackhi_epi16(row1TempTemp1, row12Temp2Temp); // A3 E3 C3 G3 A4 E4 C4 G4 __m128h row1 = _mm_castsi128_ph(_mm_unpacklo_epi16(row12Temp1, row12Temp2));// A1 B1 E1 F1 C1 D1 G1 H1 __m128h row2 = _mm_castsi128_ph(_mm_unpackhi_epi16(row12Temp1, row12Temp2));// A2 B2 E2 F2 C2 D2 G2 H2 @@ -1307,20 +1442,22 @@ namespace Crafter { } else if constexpr(std::is_same_v) { __m256h mulA = _mm256_mul_ph(A0.v, A1.v); __m256h mulC = _mm256_mul_ph(C0.v, C1.v); + __m256i row12Temp1 = _mm256_unpacklo_epi16(_mm256_castph_si256(mulA), _mm256_castph_si256(mulC)); // A1 C1 A2 C2 A3 C3 A4 C4 __m256i row34Temp1 = _mm256_unpackhi_epi16(_mm256_castph_si256(mulA), _mm256_castph_si256(mulC)); // B1 D1 B2 D2 B3 D3 B4 D4 __m256i row1TempTemp1 = row12Temp1; - __m256i row5TempTemp1 = row34Temp1; __m256h mulE = _mm256_mul_ph(E0.v, E1.v); __m256h mulG = _mm256_mul_ph(G0.v, G1.v); + __m256i row12Temp2 = _mm256_unpacklo_epi16(_mm256_castph_si256(mulE), _mm256_castph_si256(mulG)); // E1 G1 E2 G2 E3 G3 E4 G4 + __m256i row12Temp2Temp = row12Temp2; __m256i row34Temp2 = _mm256_unpackhi_epi16(_mm256_castph_si256(mulE), _mm256_castph_si256(mulG)); // F1 H1 F2 H2 F3 H3 F4 H4 row12Temp1 = _mm256_unpacklo_epi16(row12Temp1, row12Temp2); // A1 E1 C1 G1 A2 E2 C2 G2 row12Temp2 = _mm256_unpacklo_epi16(row34Temp1, row34Temp2); // B1 F1 D1 H1 B2 F2 D2 H2 - row34Temp1 = _mm256_unpackhi_epi16(row1TempTemp1, row34Temp1); // A3 E3 
C3 G3 A4 E4 C4 G4 - row34Temp2 = _mm256_unpackhi_epi16(row5TempTemp1, row34Temp2); // B3 F3 D3 H3 B4 F4 D4 H4 + row34Temp2 = _mm256_unpackhi_epi16(row34Temp1, row34Temp2); // B3 F3 D3 H3 B4 F4 D4 H4 + row34Temp1 = _mm256_unpackhi_epi16(row1TempTemp1, row12Temp2Temp); // A3 E3 C3 G3 A4 E4 C4 G4 __m256h row1 = _mm256_castsi256_ph(_mm256_unpacklo_epi16(row12Temp1, row12Temp2));// A1 B1 E1 F1 C1 D1 G1 H1 __m256h row2 = _mm256_castsi256_ph(_mm256_unpackhi_epi16(row12Temp1, row12Temp2));// A2 B2 E2 F2 C2 D2 G2 H2 @@ -1335,20 +1472,22 @@ namespace Crafter { } else { __m512h mulA = _mm512_mul_ph(A0.v, A1.v); __m512h mulC = _mm512_mul_ph(C0.v, C1.v); + __m512i row12Temp1 = _mm512_unpacklo_epi16(_mm512_castph_si512(mulA), _mm512_castph_si512(mulC)); // A1 C1 A2 C2 A3 C3 A4 C4 __m512i row34Temp1 = _mm512_unpackhi_epi16(_mm512_castph_si512(mulA), _mm512_castph_si512(mulC)); // B1 D1 B2 D2 B3 D3 B4 D4 __m512i row1TempTemp1 = row12Temp1; - __m512i row5TempTemp1 = row34Temp1; __m512h mulE = _mm512_mul_ph(E0.v, E1.v); __m512h mulG = _mm512_mul_ph(G0.v, G1.v); + __m512i row12Temp2 = _mm512_unpacklo_epi16(_mm512_castph_si512(mulE), _mm512_castph_si512(mulG)); // E1 G1 E2 G2 E3 G3 E4 G4 + __m512i row12Temp2Temp = row12Temp2; __m512i row34Temp2 = _mm512_unpackhi_epi16(_mm512_castph_si512(mulE), _mm512_castph_si512(mulG)); // F1 H1 F2 H2 F3 H3 F4 H4 row12Temp1 = _mm512_unpacklo_epi16(row12Temp1, row12Temp2); // A1 E1 C1 G1 A2 E2 C2 G2 row12Temp2 = _mm512_unpacklo_epi16(row34Temp1, row34Temp2); // B1 F1 D1 H1 B2 F2 D2 H2 - row34Temp1 = _mm512_unpackhi_epi16(row1TempTemp1, row34Temp1); // A3 E3 C3 G3 A4 E4 C4 G4 - row34Temp2 = _mm512_unpackhi_epi16(row5TempTemp1, row34Temp2); // B3 F3 D3 H3 B4 F4 D4 H4 + row34Temp2 = _mm512_unpackhi_epi16(row34Temp1, row34Temp2); // B3 F3 D3 H3 B4 F4 D4 H4 + row34Temp1 = _mm512_unpackhi_epi16(row1TempTemp1, row12Temp2Temp); // A3 E3 C3 G3 A4 E4 C4 G4 __m512h row1 = _mm512_castsi512_ph(_mm512_unpacklo_epi16(row12Temp1, row12Temp2));// A1 B1 E1 F1 C1 D1 G1 
H1 __m512h row2 = _mm512_castsi512_ph(_mm512_unpackhi_epi16(row12Temp1, row12Temp2));// A2 B2 E2 F2 C2 D2 G2 H2 @@ -1363,7 +1502,7 @@ namespace Crafter { } } - constexpr static VectorF16<1, Packing*2> Dot( + constexpr static VectorF16<1, Packing*2> DotNoShuffle( VectorF16 A0, VectorF16 A1, VectorF16 E0, VectorF16 E1 ) requires(Len == 2 && Packing*Len == Alignment) { @@ -1397,10 +1536,7 @@ namespace Crafter { __m256h row2 = _mm256_castsi256_ph(_mm256_unpackhi_epi16(row12Temp1, row12Temp2));// A2 B2 C2 D2 E2 F2 G2 H2 __m256h result = _mm256_add_ph(row1, row2); - VectorF16<16, 1> vec(result); - vec = vec.template Shuffle<{{0,1,2,3,8,9,10,11,4,5,6,7,12,13,14,15}}>(); - - return VectorF16<1, 16>(vec.v); + return result; } else { __m512h mulA = _mm512_mul_ph(A0.v, A1.v); __m512h mulE = _mm512_mul_ph(E0.v, E1.v); @@ -1414,12 +1550,11 @@ namespace Crafter { __m512h row1 = _mm512_castsi512_ph(_mm512_unpacklo_epi16(row12Temp1, row12Temp2));// A1 B1 C1 D1 E1 F1 G1 H1 __m512h row2 = _mm512_castsi512_ph(_mm512_unpackhi_epi16(row12Temp1, row12Temp2));// A2 B2 C2 D2 E2 F2 G2 H2 __m512h result = _mm512_add_ph(row1, row2); - - VectorF16<32, 1> vec(result); - vec = vec.template Shuffle<{{0,1,2,3,8,9,10,11,16,17,18,19,24,25,26,27,4,5,6,7,12,13,14,15,20,21,22,23,28,29,30,31}}>(); - return VectorF16<1, 32>(vec.v); + + return result; } } + public: template ShuffleValues> constexpr static VectorF16 Blend(VectorF16 a, VectorF16 b) { diff --git a/tests/Vector.cpp b/tests/Vector.cpp index 3566f05..4068b01 100644 --- a/tests/Vector.cpp +++ b/tests/Vector.cpp @@ -412,6 +412,51 @@ std::string* TestAllCombinations() { } } + if constexpr(Len == 4 && Packing*Len == VectorType::Alignment) { + { + VectorType vecA(floats); + VectorType vecC = vecA * 2; + VectorType vecE = vecA * 3; + VectorType vecG = vecA * 4; + VectorType<1, Packing*4> result = VectorType::Length(vecA, vecC, vecE, vecG); + Vector::Alignment> stored = result.Store(); + + if (!FloatEquals(stored.v[0], expectedLength[0])) { + return 
new std::string(std::format("Length 4 vecA test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0], (float)stored.v[0])); + } + + if (!FloatEquals(stored.v[Packing], expectedLength[0] * 2)) { + return new std::string(std::format("Length 4 vecC test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 2, (float)stored.v[Packing])); + } + + if (!FloatEquals(stored.v[Packing*2], expectedLength[0] * 3)) { + return new std::string(std::format("Length 4 vecE test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 3, (float)stored.v[Packing*2])); + } + + if (!FloatEquals(stored.v[Packing*3], expectedLength[0] * 4)) { + return new std::string(std::format("Length 4 vecG test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 4, (float)stored.v[Packing*3])); + } + } + + { + VectorType vecA(floats); + VectorType vecC = vecA * 2; + VectorType vecE = vecA * 3; + VectorType vecG = vecA * 4; + auto result = VectorType::Normalize(vecA, vecC, vecE, vecG); + VectorType<1, Packing*4> result2 = VectorType::Length(std::get<0>(result), std::get<1>(result), std::get<2>(result), std::get<3>(result)); + Vector::Alignment> stored = result2.Store(); + + //std::println("{}", stored); + + for(std::uint8_t i = 0; i < Len*Packing; i++) { + if (!FloatEquals(stored.v[i], T(1))) { + return new std::string(std::format("Normalize {} test failed at Len={} Packing={} Expected: {}, Got: {}", i, Len, Packing, 1, (float)stored.v[i])); + } + } + } + } + return TestAllCombinations(); } }