diff --git a/interfaces/Crafter.Math-VectorF16.cppm b/interfaces/Crafter.Math-VectorF16.cppm index c334672..20e17ab 100755 --- a/interfaces/Crafter.Math-VectorF16.cppm +++ b/interfaces/Crafter.Math-VectorF16.cppm @@ -148,7 +148,7 @@ namespace Crafter { } - constexpr void operator+=(VectorF16 b) const { + constexpr void operator+=(VectorF16 b) { if constexpr(std::is_same_v) { v = _mm_add_ph(v, b.v); } else if constexpr(std::is_same_v) { @@ -158,7 +158,7 @@ namespace Crafter { } } - constexpr void operator-=(VectorF16 b) const { + constexpr void operator-=(VectorF16 b) { if constexpr(std::is_same_v) { v = _mm_sub_ph(v, b.v); } else if constexpr(std::is_same_v) { @@ -168,7 +168,7 @@ namespace Crafter { } } - constexpr void operator*=(VectorF16 b) const { + constexpr void operator*=(VectorF16 b) { if constexpr(std::is_same_v) { v = _mm_mul_ph(v, b.v); } else if constexpr(std::is_same_v) { @@ -178,7 +178,7 @@ namespace Crafter { } } - constexpr void operator/=(VectorF16 b) const { + constexpr void operator/=(VectorF16 b) { if constexpr(std::is_same_v) { v = _mm_div_ph(v, b.v); } else if constexpr(std::is_same_v) { @@ -188,48 +188,48 @@ namespace Crafter { } } - constexpr VectorF16 operator+(_Float16 b) const { + constexpr VectorF16 operator+(_Float16 b) { VectorF16 vB(b); - return this + vB; + return *this + vB; } - constexpr VectorF16 operator-(_Float16 b) const { + constexpr VectorF16 operator-(_Float16 b) { VectorF16 vB(b); - return this - vB; + return *this - vB; } - constexpr VectorF16 operator*(_Float16 b) const { + constexpr VectorF16 operator*(_Float16 b) { VectorF16 vB(b); - return this * vB; + return *this * vB; } - constexpr VectorF16 operator/(_Float16 b) const { + constexpr VectorF16 operator/(_Float16 b) { VectorF16 vB(b); - return this / vB; + return *this / vB; } - constexpr void operator+=(_Float16 b) const { + constexpr void operator+=(_Float16 b) { VectorF16 vB(b); - this += vB; + *this += vB; } - constexpr void operator-=(_Float16 b) const { + constexpr void operator-=(_Float16 b) { VectorF16 vB(b); - this -= vB; + *this -= vB; } - constexpr void operator*=(_Float16 b) const { + constexpr void operator*=(_Float16 b) { VectorF16 vB(b); - this *= vB; + *this *= vB; } - constexpr void operator/=(_Float16 b) const { + constexpr void operator/=(_Float16 b) { VectorF16 vB(b); - this /= vB; + *this /= vB; } constexpr VectorF16 operator-(){ - return Negate(); + return Negate(); } constexpr bool operator==(VectorF16 b) const { @@ -281,28 +281,88 @@ namespace Crafter { } constexpr VectorF16 Cos() { - if constexpr(std::is_same_v) { - return VectorF16(_mm_cos_ph(v)); - } else if constexpr(std::is_same_v) { - return VectorF16(_mm256_cos_ph(v)); + if constexpr (std::is_same_v) { + __m256 wide = _mm256_cvtph_ps(_mm_castph_si128(v)); + wide = cos_f32x8(wide); + return VectorF16( + _mm_castsi128_ph(_mm256_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT))); + + } else if constexpr (std::is_same_v) { + __m512 wide = _mm512_cvtph_ps(_mm256_castph_si256(v)); + wide = cos_f32x16(wide); + return VectorF16( + _mm256_castsi256_ph(_mm512_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT))); + } else { - return VectorF16(_mm512_cos_ph(v)); + __m256i lo = _mm512_castsi512_si256(_mm512_castph_si512(v)); + __m256i hi = _mm512_extracti64x4_epi64(_mm512_castph_si512(v), 1); + __m256i lo_ph = _mm512_cvtps_ph(cos_f32x16(_mm512_cvtph_ps(lo)), _MM_FROUND_TO_NEAREST_INT); + __m256i hi_ph = _mm512_cvtps_ph(cos_f32x16(_mm512_cvtph_ps(hi)), _MM_FROUND_TO_NEAREST_INT); + return VectorF16( + _mm512_castsi512_ph(_mm512_inserti64x4(_mm512_castsi256_si512(lo_ph), hi_ph, 1))); } - } + } constexpr VectorF16 Sin() { - if constexpr(std::is_same_v) { - return VectorF16(_mm_sin_ph(v)); - } else if constexpr(std::is_same_v) { - return VectorF16(_mm256_sin_ph(v)); + if constexpr (std::is_same_v) { + __m256 wide = _mm256_cvtph_ps(_mm_castph_si128(v)); + wide = sin_f32x8(wide); + return VectorF16(_mm_castsi128_ph(_mm256_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT))); + + } else if constexpr (std::is_same_v) { + __m512 wide = _mm512_cvtph_ps(_mm256_castph_si256(v)); + wide = sin_f32x16(wide); + return VectorF16(_mm256_castsi256_ph(_mm512_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT))); + } else { - return VectorF16(_mm512_sin_ph(v)); + __m256i lo = _mm512_castsi512_si256(_mm512_castph_si512(v)); + __m256i hi = _mm512_extracti64x4_epi64(_mm512_castph_si512(v), 1); + __m256i lo_ph = _mm512_cvtps_ph(sin_f32x16(_mm512_cvtph_ps(lo)), _MM_FROUND_TO_NEAREST_INT); + __m256i hi_ph = _mm512_cvtps_ph(sin_f32x16(_mm512_cvtph_ps(hi)), _MM_FROUND_TO_NEAREST_INT); + return VectorF16(_mm512_castsi512_ph(_mm512_inserti64x4(_mm512_castsi256_si512(lo_ph), hi_ph, 1))); } - } + } + + std::tuple, VectorF16> SinCos() { + if constexpr (std::is_same_v) { + __m256 wide = _mm256_cvtph_ps(_mm_castph_si128(v)); + __m256 s, c; + sincos_f32x8(wide, s, c); + return { + VectorF16(_mm_castsi128_ph(_mm256_cvtps_ph(s, _MM_FROUND_TO_NEAREST_INT))), + VectorF16(_mm_castsi128_ph(_mm256_cvtps_ph(c, _MM_FROUND_TO_NEAREST_INT))) + }; + + } else if constexpr (std::is_same_v) { + __m512 wide = _mm512_cvtph_ps(_mm256_castph_si256(v)); + __m512 s, c; + sincos_f32x16(wide, s, c); + return { + VectorF16(_mm256_castsi256_ph(_mm512_cvtps_ph(s, _MM_FROUND_TO_NEAREST_INT))), + VectorF16(_mm256_castsi256_ph(_mm512_cvtps_ph(c, _MM_FROUND_TO_NEAREST_INT))) + }; + + } else { + __m256i lo = _mm512_castsi512_si256(_mm512_castph_si512(v)); + __m256i hi = _mm512_extracti64x4_epi64(_mm512_castph_si512(v), 1); + + __m512 s_lo, c_lo, s_hi, c_hi; + sincos_f32x16(_mm512_cvtph_ps(lo), s_lo, c_lo); + sincos_f32x16(_mm512_cvtph_ps(hi), s_hi, c_hi); + + auto pack = [](__m256i lo_ph, __m256i hi_ph) { + return _mm512_castsi512_ph(_mm512_inserti64x4(_mm512_castsi256_si512(lo_ph), hi_ph, 1)); + }; + return { + VectorF16(pack(_mm512_cvtps_ph(s_lo, _MM_FROUND_TO_NEAREST_INT), _mm512_cvtps_ph(s_hi, _MM_FROUND_TO_NEAREST_INT))), + VectorF16(pack( _mm512_cvtps_ph(c_lo, _MM_FROUND_TO_NEAREST_INT), _mm512_cvtps_ph(c_hi, _MM_FROUND_TO_NEAREST_INT))) + }; + } + } template values> - constexpr std::array Negate() { - std::array mask = GetShuffleMaskEpi32(); + constexpr VectorF16 Negate() { + std::array mask = GetNegateMask(); if constexpr(std::is_same_v) { return VectorF16(_mm_castsi128_ph(_mm_xor_si128(_mm_castph_si128(v), _mm_loadu_epi16(mask.data())))); } else if constexpr(std::is_same_v) { @@ -341,21 +401,21 @@ namespace Crafter { static constexpr VectorF16 MulitplyAdd(VectorF16 a, VectorF16 b, VectorF16 add) { if constexpr(std::is_same_v) { - return VectorF16(_mm_fmadd_ph(a.v, b.v, add)); + return VectorF16(_mm_fmadd_ph(a.v, b.v, add.v)); } else if constexpr(std::is_same_v) { - return VectorF16(_mm256_fmadd_ph(a.v, b.v, add)); + return VectorF16(_mm256_fmadd_ph(a.v, b.v, add.v)); } else { - return VectorF16(_mm512_fmadd_ph(a.v, b.v, add)); + return VectorF16(_mm512_fmadd_ph(a.v, b.v, add.v)); } } static constexpr VectorF16 MulitplySub(VectorF16 a, VectorF16 b, VectorF16 sub) { if constexpr(std::is_same_v) { - return VectorF16(_mm_fmsub_ph(a.v, b.v, sub)); + return VectorF16(_mm_fmsub_ph(a.v, b.v, sub.v)); } else if constexpr(std::is_same_v) { - return VectorF16(_mm256_fmsub_ph(a.v, b.v, sub)); + return VectorF16(_mm256_fmsub_ph(a.v, b.v, sub.v)); } else { - return VectorF16(_mm512_fmsub_ph(a.v, b.v, sub)); + return VectorF16(_mm512_fmsub_ph(a.v, b.v, sub.v)); } } @@ -1227,7 +1287,7 @@ namespace Crafter { return shuffleMask; } - consteval std::array GetAllTrue() { + static consteval std::array GetAllTrue() { std::array arr{}; arr.fill(true); return arr; @@ -1288,6 +1348,216 @@ namespace Crafter { } return mask; } + + static constexpr float two_over_pi = 0.6366197723675814f; + static constexpr float pi_over_2_hi = 1.5707963267341256f; + static constexpr float pi_over_2_lo = 6.077100506506192e-11f; + + // Cos polynomial on [-pi/4, pi/4]: c0 + c2*r^2 + c4*r^4 + ... + static constexpr float c0 = 1.0f; + static constexpr float c2 = -0.4999999642372f; + static constexpr float c4 = 0.0416666418707f; + static constexpr float c6 = -0.0013888397720f; + static constexpr float c8 = 0.0000248015873f; + static constexpr float c10 = -0.0000002752258f; + + // Sin polynomial on [-pi/4, pi/4]: r * (1 + s1*r^2 + s3*r^4 + ...) + static constexpr float s1 = -0.1666666641831f; + static constexpr float s3 = 0.0083333293858f; + static constexpr float s5 = -0.0001984090955f; + static constexpr float s7 = 0.0000027526372f; + static constexpr float s9 = -0.0000000239013f; + + // Reduce |x| into [-pi/4, pi/4], return reduced value and quadrant + constexpr void range_reduce_f32x8(__m256 ax, __m256& r, __m256& r2, __m256i& q) { + __m256 fq = _mm256_round_ps(_mm256_mul_ps(ax, _mm256_set1_ps(two_over_pi)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + q = _mm256_cvtps_epi32(fq); + r = _mm256_sub_ps(ax, _mm256_mul_ps(fq, _mm256_set1_ps(pi_over_2_hi))); + r = _mm256_sub_ps(r, _mm256_mul_ps(fq, _mm256_set1_ps(pi_over_2_lo))); + r2 = _mm256_mul_ps(r, r); + } + + + // cos(x): use cos_poly when q even, sin_poly when q odd; negate if (q+1)&2 + constexpr __m256 cos_f32x8(__m256 x) { + const __m256 sign_mask = _mm256_set1_ps(-0.0f); + __m256 ax = _mm256_andnot_ps(sign_mask, x); + + __m256 r, r2; __m256i q; + range_reduce_f32x8(ax, r, r2, q); + + __m256 cos_r, sin_r; + sincos_poly_f32x8(r, r2, cos_r, sin_r); + + __m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1)); + __m256 use_sin = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1))); + __m256 result = _mm256_blendv_ps(cos_r, sin_r, use_sin); + + __m256i need_neg = _mm256_and_si256( + _mm256_add_epi32(q, _mm256_set1_epi32(1)), _mm256_set1_epi32(2)); + __m256 neg_mask = _mm256_castsi256_ps(_mm256_slli_epi32(need_neg, 30)); + return _mm256_xor_ps(result, neg_mask); + } + + constexpr void sincos_poly_f32x8(__m256 r, __m256 r2, __m256& cos_r, __m256& sin_r) { + cos_r = _mm256_fmadd_ps(_mm256_set1_ps(c10), r2, _mm256_set1_ps(c8)); + cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c6)); + cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c4)); + cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c2)); + cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c0)); + + sin_r = _mm256_fmadd_ps(_mm256_set1_ps(s9), r2, _mm256_set1_ps(s7)); + sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s5)); + sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s3)); + sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s1)); + sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(1.0f)); + sin_r = _mm256_mul_ps(sin_r, r); + } + + // sin(x): use sin_poly when q even, cos_poly when q odd; negate if q&2; respect input sign + constexpr __m256 sin_f32x8(__m256 x) { + const __m256 sign_mask = _mm256_set1_ps(-0.0f); + __m256 x_sign = _mm256_and_ps(x, sign_mask); + __m256 ax = _mm256_andnot_ps(sign_mask, x); + + __m256 r, r2; __m256i q; + range_reduce_f32x8(ax, r, r2, q); + + __m256 cos_r, sin_r; + sincos_poly_f32x8(r, r2, cos_r, sin_r); + + __m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1)); + __m256 use_cos = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1))); + __m256 result = _mm256_blendv_ps(sin_r, cos_r, use_cos); + + __m256i need_neg = _mm256_and_si256(q, _mm256_set1_epi32(2)); + __m256 neg_mask = _mm256_castsi256_ps(_mm256_slli_epi32(need_neg, 30)); + result = _mm256_xor_ps(result, neg_mask); + + // Apply original sign of x + return _mm256_xor_ps(result, x_sign); + } + + // --- 512-bit helpers --- + + constexpr void range_reduce_f32x16(__m512 ax, __m512& r, __m512& r2, __m512i& q) { + __m512 fq = _mm512_roundscale_ps(_mm512_mul_ps(ax, _mm512_set1_ps(two_over_pi)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + q = _mm512_cvtps_epi32(fq); + r = _mm512_sub_ps(ax, _mm512_mul_ps(fq, _mm512_set1_ps(pi_over_2_hi))); + r = _mm512_sub_ps(r, _mm512_mul_ps(fq, _mm512_set1_ps(pi_over_2_lo))); + r2 = _mm512_mul_ps(r, r); + } + + constexpr void sincos_poly_f32x16(__m512 r, __m512 r2, __m512& cos_r, __m512& sin_r) { + cos_r = _mm512_fmadd_ps(_mm512_set1_ps(c10), r2, _mm512_set1_ps(c8)); + cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c6)); + cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c4)); + cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c2)); + cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c0)); + + sin_r = _mm512_fmadd_ps(_mm512_set1_ps(s9), r2, _mm512_set1_ps(s7)); + sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s5)); + sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s3)); + sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s1)); + sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(1.0f)); + sin_r = _mm512_mul_ps(sin_r, r); + } + + constexpr __m512 cos_f32x16(__m512 x) { + __m512 ax = _mm512_abs_ps(x); + + __m512 r, r2; __m512i q; + range_reduce_f32x16(ax, r, r2, q); + + __m512 cos_r, sin_r; + sincos_poly_f32x16(r, r2, cos_r, sin_r); + + __mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1)); + __m512 result = _mm512_mask_blend_ps(odd, cos_r, sin_r); + + __m512i need_neg = _mm512_and_si512( + _mm512_add_epi32(q, _mm512_set1_epi32(1)), _mm512_set1_epi32(2)); + __m512 neg_mask = _mm512_castsi512_ps(_mm512_slli_epi32(need_neg, 30)); + return _mm512_xor_ps(result, neg_mask); + } + + constexpr __m512 sin_f32x16(__m512 x) { + __m512 x_sign = _mm512_and_ps(x, _mm512_set1_ps(-0.0f)); + __m512 ax = _mm512_abs_ps(x); + + __m512 r, r2; __m512i q; + range_reduce_f32x16(ax, r, r2, q); + + __m512 cos_r, sin_r; + sincos_poly_f32x16(r, r2, cos_r, sin_r); + + __mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1)); + __m512 result = _mm512_mask_blend_ps(odd, sin_r, cos_r); + + __m512i need_neg = _mm512_and_si512(q, _mm512_set1_epi32(2)); + __m512 neg_mask = _mm512_castsi512_ps(_mm512_slli_epi32(need_neg, 30)); + result = _mm512_xor_ps(result, neg_mask); + + return _mm512_xor_ps(result, x_sign); + } + + // --- 256-bit sincos --- + constexpr void sincos_f32x8(__m256 x, __m256& out_sin, __m256& out_cos) { + const __m256 sign_mask = _mm256_set1_ps(-0.0f); + __m256 x_sign = _mm256_and_ps(x, sign_mask); + __m256 ax = _mm256_andnot_ps(sign_mask, x); + + __m256 r, r2; __m256i q; + range_reduce_f32x8(ax, r, r2, q); + + __m256 cos_r, sin_r; + sincos_poly_f32x8(r, r2, cos_r, sin_r); + + __m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1)); + __m256 is_odd = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1))); + + // cos: swap on odd, negate if (q+1)&2 + out_cos = _mm256_blendv_ps(cos_r, sin_r, is_odd); + __m256i cos_neg = _mm256_and_si256( + _mm256_add_epi32(q, _mm256_set1_epi32(1)), _mm256_set1_epi32(2)); + out_cos = _mm256_xor_ps(out_cos, + _mm256_castsi256_ps(_mm256_slli_epi32(cos_neg, 30))); + + // sin: swap on odd, negate if q&2, apply input sign + out_sin = _mm256_blendv_ps(sin_r, cos_r, is_odd); + __m256i sin_neg = _mm256_and_si256(q, _mm256_set1_epi32(2)); + out_sin = _mm256_xor_ps(out_sin, + _mm256_castsi256_ps(_mm256_slli_epi32(sin_neg, 30))); + out_sin = _mm256_xor_ps(out_sin, x_sign); + } + + // --- 512-bit sincos --- + constexpr void sincos_f32x16(__m512 x, __m512& out_sin, __m512& out_cos) { + __m512 x_sign = _mm512_and_ps(x, _mm512_set1_ps(-0.0f)); + __m512 ax = _mm512_abs_ps(x); + + __m512 r, r2; __m512i q; + range_reduce_f32x16(ax, r, r2, q); + + __m512 cos_r, sin_r; + sincos_poly_f32x16(r, r2, cos_r, sin_r); + + __mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1)); + + // cos + out_cos = _mm512_mask_blend_ps(odd, cos_r, sin_r); + __m512i cos_neg = _mm512_and_si512( + _mm512_add_epi32(q, _mm512_set1_epi32(1)), _mm512_set1_epi32(2)); + out_cos = _mm512_xor_ps(out_cos, + _mm512_castsi512_ps(_mm512_slli_epi32(cos_neg, 30))); + + // sin + out_sin = _mm512_mask_blend_ps(odd, sin_r, cos_r); + __m512i sin_neg = _mm512_and_si512(q, _mm512_set1_epi32(2)); + out_sin = _mm512_xor_ps(out_sin, + _mm512_castsi512_ps(_mm512_slli_epi32(sin_neg, 30))); + out_sin = _mm512_xor_ps(out_sin, x_sign); + } }; } diff --git a/project.json b/project.json index 263f1dc..4c07ec8 100644 --- a/project.json +++ b/project.json @@ -34,7 +34,7 @@ ], "tests":[ { - "name": "F16x86", + "name": "F16-x86-64-sapphirerapids", "implementations": ["tests/VectorF16"], "march": "sapphirerapids", "extends": ["lib-shared"] diff --git a/tests/VectorF16.cpp b/tests/VectorF16.cpp index 30ac645..1934a09 100644 --- a/tests/VectorF16.cpp +++ b/tests/VectorF16.cpp @@ -22,8 +22,9 @@ using namespace Crafter; extern "C" { std::string* RunTest() { + // Test 1: Load/Store functionality { - _Float16 floats[] {0,1,2,3,4,5,6,7,8}; + _Float16 floats[] {0,1,2,3,4,5,6,7}; VectorF16<8, 1> vec1(floats); Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = vec1.Store(); @@ -34,8 +35,9 @@ extern "C" { } } + // Test 2: Addition operator { - _Float16 floats[] {0,1,2,3,4,5,6,7,8}; + _Float16 floats[] {0,1,2,3,4,5,6,7}; VectorF16<8, 1> vec1(floats); VectorF16<8, 1> result = vec1 + vec1; Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = result.Store(); @@ -45,8 +47,348 @@ extern "C" { } } } + + // Test 3: Subtraction operator + { + _Float16 floats[] {0,1,2,3,4,5,6,7}; + VectorF16<8, 1> vec1(floats); + VectorF16<8, 1> result = vec1 - vec1; + Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = result.Store(); + for(std::uint8_t i = 0; i < 8; i++) { + if(stored.v[i] != floats[i] - floats[i]) { + return new std::string("Subtract does not match"); + } + } + } + + // Test 4: Multiplication operator + { + _Float16 floats[] {1,2,3,4,5,6,7,8}; + VectorF16<8, 1> vec1(floats); + VectorF16<8, 1> result = vec1 * vec1; + Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = result.Store(); + for(std::uint8_t i = 0; i < 8; i++) { + if(stored.v[i] != floats[i] * floats[i]) { + return new std::string("Multiply does not match"); + } + } + } + + // Test 5: Division operator + { + _Float16 floats[] {2,4,6,8,10,12,14,16}; + VectorF16<8, 1> vec1(floats); + VectorF16<8, 1> result = vec1 / vec1; + Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = result.Store(); + for(std::uint8_t i = 0; i < 8; i++) { + if(stored.v[i] != floats[i] / floats[i]) { + return new std::string("Divide does not match"); + } + } + } + + // Test 6: Compound addition operator + { + _Float16 floats[] {1,2,3,4,5,6,7,8}; + VectorF16<8, 1> vec1(floats); + VectorF16<8, 1> vec2(floats); + vec1 += vec2; + Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = vec1.Store(); + for(std::uint8_t i = 0; i < 8; i++) { + if(stored.v[i] != floats[i] + floats[i]) { + return new std::string("Compound Add does not match"); + } + } + } + + // Test 7: Compound subtraction operator + { + _Float16 floats[] {1,2,3,4,5,6,7,8}; + VectorF16<8, 1> vec1(floats); + VectorF16<8, 1> vec2(floats); + vec1 -= vec2; + Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = vec1.Store(); + for(std::uint8_t i = 0; i < 8; i++) { + if(stored.v[i] != floats[i] - floats[i]) { + return new std::string("Compound Subtract does not match"); + } + } + } + + // Test 8: Compound multiplication operator + { + _Float16 floats[] {1,2,3,4,5,6,7,8}; + VectorF16<8, 1> vec1(floats); + VectorF16<8, 1> vec2(floats); + vec1 *= vec2; + Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = vec1.Store(); + for(std::uint8_t i = 0; i < 8; i++) { + if(stored.v[i] != floats[i] * floats[i]) { + return new std::string("Compound Multiply does not match"); + } + } + } + + // Test 9: Compound division operator + { + _Float16 floats[] {2,4,6,8,10,12,14,16}; + VectorF16<8, 1> vec1(floats); + VectorF16<8, 1> vec2(floats); + vec1 /= vec2; + Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = vec1.Store(); + for(std::uint8_t i = 0; i < 8; i++) { + if(stored.v[i] != floats[i] / floats[i]) { + return new std::string("Compound Divide does not match"); + } + } + } + + // Test 10: Scalar addition + { + _Float16 floats[] {1,2,3,4,5,6,7,8}; + VectorF16<8, 1> vec1(floats); + VectorF16<8, 1> result = vec1 + _Float16(1.0); + Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = result.Store(); + for(std::uint8_t i = 0; i < 8; i++) { + if(stored.v[i] != floats[i] + 1.0) { + return new std::string("Scalar Add does not match"); + } + } + } + + // Test 11: Scalar subtraction + { + _Float16 floats[] {1,2,3,4,5,6,7,8}; + VectorF16<8, 1> vec1(floats); + VectorF16<8, 1> result = vec1 - _Float16(1.0); + Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = result.Store(); + for(std::uint8_t i = 0; i < 8; i++) { + if(stored.v[i] != floats[i] - 1.0) { + return new std::string("Scalar Subtract does not match"); + } + } + } + + // Test 12: Scalar multiplication + { + _Float16 floats[] {1,2,3,4,5,6,7,8}; + VectorF16<8, 1> vec1(floats); + VectorF16<8, 1> result = vec1 * _Float16(2.0); + Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = result.Store(); + for(std::uint8_t i = 0; i < 8; i++) { + if(stored.v[i] != floats[i] * 2.0) { + return new std::string("Scalar Multiply does not match"); + } + } + } + + // Test 13: Scalar division + { + _Float16 floats[] {2,4,6,8,10,12,14,16}; + VectorF16<8, 1> vec1(floats); + VectorF16<8, 1> result = vec1 / _Float16(2.0); + Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = result.Store(); + for(std::uint8_t i = 0; i < 8; i++) { + if(stored.v[i] != floats[i] / 2.0) { + return new std::string("Scalar Divide does not match"); + } + } + } + + // Test 14: Compound scalar addition + { + _Float16 floats[] {1,2,3,4,5,6,7,8}; + VectorF16<8, 1> vec1(floats); + vec1 += _Float16(1.0); + Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = vec1.Store(); + for(std::uint8_t i = 0; i < 8; i++) { + if(stored.v[i] != floats[i] + 1.0) { + return new std::string("Compound Scalar Add does not match"); + } + } + } + + // Test 15: Compound scalar subtraction + { + _Float16 floats[] {1,2,3,4,5,6,7,8}; + VectorF16<8, 1> vec1(floats); + vec1 -= _Float16(1.0); + Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = vec1.Store(); + for(std::uint8_t i = 0; i < 8; i++) { + if(stored.v[i] != floats[i] - 1.0) { + return new std::string("Compound Scalar Subtract does not match"); + } + } + } + + // Test 16: Compound scalar multiplication + { + _Float16 floats[] {1,2,3,4,5,6,7,8}; + VectorF16<8, 1> vec1(floats); + vec1 *= _Float16(2.0); + Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = vec1.Store(); + for(std::uint8_t i = 0; i < 8; i++) { + if(stored.v[i] != floats[i] * 2.0) { + return new std::string("Compound Scalar Multiply does not match"); + } + } + } + + // Test 17: Compound scalar division + { + _Float16 floats[] {2,4,6,8,10,12,14,16}; + VectorF16<8, 1> vec1(floats); + vec1 /= _Float16(2.0); + Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = vec1.Store(); + for(std::uint8_t i = 0; i < 8; i++) { + if(stored.v[i] != floats[i] / 2.0) { + return new std::string("Compound Scalar Divide does not match"); + } + } + } + + // Test 18: Equality operator + { + _Float16 floats[] {1,2,3,4,5,6,7,8}; + VectorF16<8, 1> vec1(floats); + VectorF16<8, 1> vec2(floats); + if (!(vec1 == vec2)) { + return new std::string("Equality operator does not match"); + } + } + + // Test 19: Inequality operator + { + _Float16 floats1[] {1,2,3,4,5,6,7,8}; + _Float16 floats2[] {2,3,4,5,6,7,8,9}; + VectorF16<8, 1> vec1(floats1); + VectorF16<8, 1> vec2(floats2); + if (!(vec1 != vec2)) { + return new std::string("Inequality operator does not match"); + } + } + + // Test 20: Negation operator + { + _Float16 floats[] {1,2,3,4,5,6,7,8}; + VectorF16<8, 1> vec1(floats); + VectorF16<8, 1> result = -vec1; + Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = result.Store(); + for(std::uint8_t i = 0; i < 8; i++) { + if(stored.v[i] != -floats[i]) { + return new std::string("Negation operator does not match"); + } + } + } + + // Test 21: Length calculation + { + _Float16 floats[] {3,4,0,0,0,0,0,0}; + VectorF16<8, 1> vec1(floats); + _Float16 length = vec1.Length(); + _Float16 expectedLength = 5.0; // sqrt(3^2 + 4^2) + if (std::abs((float)length - (float)expectedLength) > 0.001) { + return new std::string("Length calculation does not match"); + } + } + + // Test 22: Length squared calculation + { + _Float16 floats[] {3,4,0,0,0,0,0,0}; + VectorF16<8, 1> vec1(floats); + _Float16 lengthSq = vec1.LengthSq(); + _Float16 expectedLengthSq = 25.0; // 3^2 + 4^2 + if (std::abs((float)lengthSq - (float)expectedLengthSq) > 0.001) { + return new std::string("Length squared calculation does not match"); + } + } + + // Test 25: Shuffle operation + { + _Float16 floats[] {1,2,3,4,5,6,7,8}; + VectorF16<8, 1> vec1(floats); + // Shuffle indices 0,1,2,3 -> 3,2,1,0 (reverse first 4 elements) + VectorF16<8, 1> result = vec1.template Shuffle<{{3,2,1,0,7,6,5,4}}>(); + Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = result.Store(); + if (stored.v[0] != 4 || stored.v[1] != 3 || stored.v[2] != 2 || stored.v[3] != 1) { + return new std::string("Shuffle operation does not match"); + } + } + + // Test 26: Dot product + { + _Float16 floats1[] {1,2,3,4,0,0,0,0}; + _Float16 floats2[] {2,3,4,5,0,0,0,0}; + VectorF16<8, 1> vec1(floats1); + VectorF16<8, 1> vec2(floats2); + _Float16 dot = VectorF16<8, 1>::Dot(vec1, vec2); + _Float16 expectedDot = 1*2 + 2*3 + 3*4 + 4*5; // 2 + 6 + 12 + 20 = 40 + if (std::abs((float)dot - (float)expectedDot) > 0.001) { + return new std::string("Dot product does not match"); + } + } + + // Test 27: Cross product (for 3D vectors) + { + _Float16 floats1[] {1,2,3,0,0,0,0,0}; + _Float16 floats2[] {4,5,6,0,0,0,0,0}; + VectorF16<3, 1> vec1(floats1); + VectorF16<3, 1> vec2(floats2); + VectorF16<3, 1> result = VectorF16<3, 1>::Cross(vec1, vec2); + Vector<_Float16, 3, VectorF16<3, 1>::Alignment> stored = result.Store(); + // Cross product: (1,2,3) x (4,5,6) = (2*6-3*5, 3*4-1*6, 1*5-2*4) = (-3, 6, -3) + if (stored.v[0] != -3 || stored.v[1] != 6 || stored.v[2] != -3) { + return new std::string("Cross product does not match"); + } + } + + // Test 28: Multiply-Add operation + { + _Float16 floats1[] {1,2,3,4,0,0,0,0}; + _Float16 floats2[] {2,3,4,5,0,0,0,0}; + _Float16 floats3[] {1,1,1,1,0,0,0,0}; + VectorF16<8, 1> vec1(floats1); + VectorF16<8, 1> vec2(floats2); + VectorF16<8, 1> vec3(floats3); + VectorF16<8, 1> result = VectorF16<8, 1>::MulitplyAdd(vec1, vec2, vec3); + Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = result.Store(); + // Should compute (1*2 + 1, 2*3 + 1, 3*4 + 1, 4*5 + 1, ...) = (3, 7, 13, 21, ...) + for(std::uint8_t i = 0; i < 4; i++) { + if(stored.v[i] != floats1[i]*floats2[i] + floats3[i]) { + return new std::string("Multiply-Add operation does not match"); + } + } + } + + // Test 29: Multiply-Subtract operation + { + _Float16 floats1[] {1,2,3,4,0,0,0,0}; + _Float16 floats2[] {2,3,4,5,0,0,0,0}; + _Float16 floats3[] {1,1,1,1,0,0,0,0}; + VectorF16<8, 1> vec1(floats1); + VectorF16<8, 1> vec2(floats2); + VectorF16<8, 1> vec3(floats3); + VectorF16<8, 1> result = VectorF16<8, 1>::MulitplySub(vec1, vec2, vec3); + Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = result.Store(); + // Should compute (1*2 - 1, 2*3 - 1, 3*4 - 1, 4*5 - 1, ...) = (1, 5, 11, 19, ...) + for(std::uint8_t i = 0; i < 4; i++) { + if(stored.v[i] != floats1[i]*floats2[i] - floats3[i]) { + return new std::string("Multiply-Subtract operation does not match"); + } + } + } + + // Test 30: Constructor with single value + { + VectorF16<8, 1> vec1(_Float16(5.0)); + Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = vec1.Store(); + for(std::uint8_t i = 0; i < 8; i++) { + if(stored.v[i] != 5.0) { + return new std::string("Single value constructor does not match"); + } + } + } + return nullptr; } -} - - +} \ No newline at end of file