F16 tests

This commit is contained in:
Jorijn van der Graaf 2026-03-24 02:09:28 +01:00
commit b2b4ca9c4d
3 changed files with 660 additions and 48 deletions

View file

@ -148,7 +148,7 @@ namespace Crafter {
}
constexpr void operator+=(VectorF16<Len, Packing> b) const {
constexpr void operator+=(VectorF16<Len, Packing> b) {
if constexpr(std::is_same_v<VectorType, __m128h>) {
v = _mm_add_ph(v, b.v);
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
@ -158,7 +158,7 @@ namespace Crafter {
}
}
constexpr void operator-=(VectorF16<Len, Packing> b) const {
constexpr void operator-=(VectorF16<Len, Packing> b) {
if constexpr(std::is_same_v<VectorType, __m128h>) {
v = _mm_sub_ph(v, b.v);
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
@ -168,7 +168,7 @@ namespace Crafter {
}
}
constexpr void operator*=(VectorF16<Len, Packing> b) const {
constexpr void operator*=(VectorF16<Len, Packing> b) {
if constexpr(std::is_same_v<VectorType, __m128h>) {
v = _mm_mul_ph(v, b.v);
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
@ -178,7 +178,7 @@ namespace Crafter {
}
}
constexpr void operator/=(VectorF16<Len, Packing> b) const {
constexpr void operator/=(VectorF16<Len, Packing> b) {
if constexpr(std::is_same_v<VectorType, __m128h>) {
v = _mm_div_ph(v, b.v);
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
@ -188,48 +188,48 @@ namespace Crafter {
}
}
constexpr VectorF16<Len, Packing> operator+(_Float16 b) const {
constexpr VectorF16<Len, Packing> operator+(_Float16 b) {
VectorF16<Len, Packing> vB(b);
return this + vB;
return *this + vB;
}
constexpr VectorF16<Len, Packing> operator-(_Float16 b) const {
constexpr VectorF16<Len, Packing> operator-(_Float16 b) {
VectorF16<Len, Packing> vB(b);
return this - vB;
return *this - vB;
}
constexpr VectorF16<Len, Packing> operator*(_Float16 b) const {
constexpr VectorF16<Len, Packing> operator*(_Float16 b) {
VectorF16<Len, Packing> vB(b);
return this * vB;
return *this * vB;
}
constexpr VectorF16<Len, Packing> operator/(_Float16 b) const {
constexpr VectorF16<Len, Packing> operator/(_Float16 b) {
VectorF16<Len, Packing> vB(b);
return this / vB;
return *this / vB;
}
constexpr void operator+=(_Float16 b) const {
constexpr void operator+=(_Float16 b) {
VectorF16<Len, Packing> vB(b);
this += vB;
*this += vB;
}
constexpr void operator-=(_Float16 b) const {
constexpr void operator-=(_Float16 b) {
VectorF16<Len, Packing> vB(b);
this -= vB;
*this -= vB;
}
constexpr void operator*=(_Float16 b) const {
constexpr void operator*=(_Float16 b) {
VectorF16<Len, Packing> vB(b);
this *= vB;
*this *= vB;
}
constexpr void operator/=(_Float16 b) const {
constexpr void operator/=(_Float16 b) {
VectorF16<Len, Packing> vB(b);
this /= vB;
*this /= vB;
}
constexpr VectorF16<Len, Packing> operator-(){
return Negate<GetAllTrue>();
return Negate<GetAllTrue()>();
}
constexpr bool operator==(VectorF16<Len, Packing> b) const {
@ -281,28 +281,88 @@ namespace Crafter {
}
constexpr VectorF16<Len, Packing> Cos() {
if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF16<Len, Packing>(_mm_cos_ph(v));
} else if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF16<Len, Packing>(_mm256_cos_ph(v));
if constexpr (std::is_same_v<VectorType, __m128h>) {
__m256 wide = _mm256_cvtph_ps(_mm_castph_si128(v));
wide = cos_f32x8(wide);
return VectorF16<Len, Packing>(
_mm_castsi128_ph(_mm256_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT)));
} else if constexpr (std::is_same_v<VectorType, __m256h>) {
__m512 wide = _mm512_cvtph_ps(_mm256_castph_si256(v));
wide = cos_f32x16(wide);
return VectorF16<Len, Packing>(
_mm256_castsi256_ph(_mm512_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT)));
} else {
return VectorF16<Len, Packing>(_mm512_cos_ph(v));
__m256i lo = _mm512_castsi512_si256(_mm512_castph_si512(v));
__m256i hi = _mm512_extracti64x4_epi64(_mm512_castph_si512(v), 1);
__m256i lo_ph = _mm512_cvtps_ph(cos_f32x16(_mm512_cvtph_ps(lo)), _MM_FROUND_TO_NEAREST_INT);
__m256i hi_ph = _mm512_cvtps_ph(cos_f32x16(_mm512_cvtph_ps(hi)), _MM_FROUND_TO_NEAREST_INT);
return VectorF16<Len, Packing>(
_mm512_castsi512_ph(_mm512_inserti64x4(_mm512_castsi256_si512(lo_ph), hi_ph, 1)));
}
}
}
constexpr VectorF16<Len, Packing> Sin() {
if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF16<Len, Packing>(_mm_sin_ph(v));
} else if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF16<Len, Packing>(_mm256_sin_ph(v));
if constexpr (std::is_same_v<VectorType, __m128h>) {
__m256 wide = _mm256_cvtph_ps(_mm_castph_si128(v));
wide = sin_f32x8(wide);
return VectorF16<Len, Packing>(_mm_castsi128_ph(_mm256_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT)));
} else if constexpr (std::is_same_v<VectorType, __m256h>) {
__m512 wide = _mm512_cvtph_ps(_mm256_castph_si256(v));
wide = sin_f32x16(wide);
return VectorF16<Len, Packing>(_mm256_castsi256_ph(_mm512_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT)));
} else {
return VectorF16<Len, Packing>(_mm512_sin_ph(v));
__m256i lo = _mm512_castsi512_si256(_mm512_castph_si512(v));
__m256i hi = _mm512_extracti64x4_epi64(_mm512_castph_si512(v), 1);
__m256i lo_ph = _mm512_cvtps_ph(sin_f32x16(_mm512_cvtph_ps(lo)), _MM_FROUND_TO_NEAREST_INT);
__m256i hi_ph = _mm512_cvtps_ph(sin_f32x16(_mm512_cvtph_ps(hi)), _MM_FROUND_TO_NEAREST_INT);
return VectorF16<Len, Packing>(_mm512_castsi512_ph(_mm512_inserti64x4(_mm512_castsi256_si512(lo_ph), hi_ph, 1)));
}
}
}
std::tuple<VectorF16<Len, Packing>, VectorF16<Len, Packing>> SinCos() {
if constexpr (std::is_same_v<VectorType, __m128h>) {
__m256 wide = _mm256_cvtph_ps(_mm_castph_si128(v));
__m256 s, c;
sincos_f32x8(wide, s, c);
return {
VectorF16<Len, Packing>(_mm_castsi128_ph(_mm256_cvtps_ph(s, _MM_FROUND_TO_NEAREST_INT))),
VectorF16<Len, Packing>(_mm_castsi128_ph(_mm256_cvtps_ph(c, _MM_FROUND_TO_NEAREST_INT)))
};
} else if constexpr (std::is_same_v<VectorType, __m256h>) {
__m512 wide = _mm512_cvtph_ps(_mm256_castph_si256(v));
__m512 s, c;
sincos_f32x16(wide, s, c);
return {
VectorF16<Len, Packing>(_mm256_castsi256_ph(_mm512_cvtps_ph(s, _MM_FROUND_TO_NEAREST_INT))),
VectorF16<Len, Packing>(_mm256_castsi256_ph(_mm512_cvtps_ph(c, _MM_FROUND_TO_NEAREST_INT)))
};
} else {
__m256i lo = _mm512_castsi512_si256(_mm512_castph_si512(v));
__m256i hi = _mm512_extracti64x4_epi64(_mm512_castph_si512(v), 1);
__m512 s_lo, c_lo, s_hi, c_hi;
sincos_f32x16(_mm512_cvtph_ps(lo), s_lo, c_lo);
sincos_f32x16(_mm512_cvtph_ps(hi), s_hi, c_hi);
auto pack = [](__m256i lo_ph, __m256i hi_ph) {
return _mm512_castsi512_ph(_mm512_inserti64x4(_mm512_castsi256_si512(lo_ph), hi_ph, 1));
};
return {
VectorF16<Len, Packing>(pack(_mm512_cvtps_ph(s_lo, _MM_FROUND_TO_NEAREST_INT), _mm512_cvtps_ph(s_hi, _MM_FROUND_TO_NEAREST_INT))),
VectorF16<Len, Packing>(pack( _mm512_cvtps_ph(c_lo, _MM_FROUND_TO_NEAREST_INT), _mm512_cvtps_ph(c_hi, _MM_FROUND_TO_NEAREST_INT)))
};
}
}
template <std::array<bool, Len> values>
constexpr std::array<std::uint16_t, Len> Negate() {
std::array<std::uint16_t, Len> mask = GetShuffleMaskEpi32<values>();
constexpr VectorF16<Len, Packing> Negate() {
std::array<std::uint16_t, Len> mask = GetNegateMask<values>();
if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF16<Len, Packing>(_mm_castsi128_ph(_mm_xor_si128(_mm_castph_si128(v), _mm_loadu_epi16(mask.data()))));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
@ -341,21 +401,21 @@ namespace Crafter {
static constexpr VectorF16<Len, Packing> MulitplyAdd(VectorF16<Len, Packing> a, VectorF16<Len, Packing> b, VectorF16<Len, Packing> add) {
if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF16<Len, Packing>(_mm_fmadd_ph(a.v, b.v, add));
return VectorF16<Len, Packing>(_mm_fmadd_ph(a.v, b.v, add.v));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
return VectorF16<Len, Packing>(_mm256_fmadd_ph(a.v, b.v, add));
return VectorF16<Len, Packing>(_mm256_fmadd_ph(a.v, b.v, add.v));
} else {
return VectorF16<Len, Packing>(_mm512_fmadd_ph(a.v, b.v, add));
return VectorF16<Len, Packing>(_mm512_fmadd_ph(a.v, b.v, add.v));
}
}
static constexpr VectorF16<Len, Packing> MulitplySub(VectorF16<Len, Packing> a, VectorF16<Len, Packing> b, VectorF16<Len, Packing> sub) {
if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF16<Len, Packing>(_mm_fmsub_ph(a.v, b.v, sub));
return VectorF16<Len, Packing>(_mm_fmsub_ph(a.v, b.v, sub.v));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
return VectorF16<Len, Packing>(_mm256_fmsub_ph(a.v, b.v, sub));
return VectorF16<Len, Packing>(_mm256_fmsub_ph(a.v, b.v, sub.v));
} else {
return VectorF16<Len, Packing>(_mm512_fmsub_ph(a.v, b.v, sub));
return VectorF16<Len, Packing>(_mm512_fmsub_ph(a.v, b.v, sub.v));
}
}
@ -1227,7 +1287,7 @@ namespace Crafter {
return shuffleMask;
}
consteval std::array<bool, Len> GetAllTrue() {
static consteval std::array<bool, Len> GetAllTrue() {
std::array<bool, Len> arr{};
arr.fill(true);
return arr;
@ -1288,6 +1348,216 @@ namespace Crafter {
}
return mask;
}
static constexpr float two_over_pi = 0.6366197723675814f;
static constexpr float pi_over_2_hi = 1.5707963267341256f;
static constexpr float pi_over_2_lo = 6.077100506506192e-11f;
// Cos polynomial on [-pi/4, pi/4]: c0 + c2*r^2 + c4*r^4 + ...
static constexpr float c0 = 1.0f;
static constexpr float c2 = -0.4999999642372f;
static constexpr float c4 = 0.0416666418707f;
static constexpr float c6 = -0.0013888397720f;
static constexpr float c8 = 0.0000248015873f;
static constexpr float c10 = -0.0000002752258f;
// Sin polynomial on [-pi/4, pi/4]: r * (1 + s1*r^2 + s3*r^4 + ...)
static constexpr float s1 = -0.1666666641831f;
static constexpr float s3 = 0.0083333293858f;
static constexpr float s5 = -0.0001984090955f;
static constexpr float s7 = 0.0000027526372f;
static constexpr float s9 = -0.0000000239013f;
// Reduce |x| into [-pi/4, pi/4], return reduced value and quadrant
constexpr void range_reduce_f32x8(__m256 ax, __m256& r, __m256& r2, __m256i& q) {
__m256 fq = _mm256_round_ps(_mm256_mul_ps(ax, _mm256_set1_ps(two_over_pi)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
q = _mm256_cvtps_epi32(fq);
r = _mm256_sub_ps(ax, _mm256_mul_ps(fq, _mm256_set1_ps(pi_over_2_hi)));
r = _mm256_sub_ps(r, _mm256_mul_ps(fq, _mm256_set1_ps(pi_over_2_lo)));
r2 = _mm256_mul_ps(r, r);
}
// cos(x): use cos_poly when q even, sin_poly when q odd; negate if (q+1)&2
constexpr __m256 cos_f32x8(__m256 x) {
const __m256 sign_mask = _mm256_set1_ps(-0.0f);
__m256 ax = _mm256_andnot_ps(sign_mask, x);
__m256 r, r2; __m256i q;
range_reduce_f32x8(ax, r, r2, q);
__m256 cos_r, sin_r;
sincos_poly_f32x8(r, r2, cos_r, sin_r);
__m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1));
__m256 use_sin = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1)));
__m256 result = _mm256_blendv_ps(cos_r, sin_r, use_sin);
__m256i need_neg = _mm256_and_si256(
_mm256_add_epi32(q, _mm256_set1_epi32(1)), _mm256_set1_epi32(2));
__m256 neg_mask = _mm256_castsi256_ps(_mm256_slli_epi32(need_neg, 30));
return _mm256_xor_ps(result, neg_mask);
}
constexpr void sincos_poly_f32x8(__m256 r, __m256 r2, __m256& cos_r, __m256& sin_r) {
cos_r = _mm256_fmadd_ps(_mm256_set1_ps(c10), r2, _mm256_set1_ps(c8));
cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c6));
cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c4));
cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c2));
cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c0));
sin_r = _mm256_fmadd_ps(_mm256_set1_ps(s9), r2, _mm256_set1_ps(s7));
sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s5));
sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s3));
sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s1));
sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(1.0f));
sin_r = _mm256_mul_ps(sin_r, r);
}
// sin(x): use sin_poly when q even, cos_poly when q odd; negate if q&2; respect input sign
constexpr __m256 sin_f32x8(__m256 x) {
const __m256 sign_mask = _mm256_set1_ps(-0.0f);
__m256 x_sign = _mm256_and_ps(x, sign_mask);
__m256 ax = _mm256_andnot_ps(sign_mask, x);
__m256 r, r2; __m256i q;
range_reduce_f32x8(ax, r, r2, q);
__m256 cos_r, sin_r;
sincos_poly_f32x8(r, r2, cos_r, sin_r);
__m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1));
__m256 use_cos = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1)));
__m256 result = _mm256_blendv_ps(sin_r, cos_r, use_cos);
__m256i need_neg = _mm256_and_si256(q, _mm256_set1_epi32(2));
__m256 neg_mask = _mm256_castsi256_ps(_mm256_slli_epi32(need_neg, 30));
result = _mm256_xor_ps(result, neg_mask);
// Apply original sign of x
return _mm256_xor_ps(result, x_sign);
}
// --- 512-bit helpers ---
constexpr void range_reduce_f32x16(__m512 ax, __m512& r, __m512& r2, __m512i& q) {
__m512 fq = _mm512_roundscale_ps(_mm512_mul_ps(ax, _mm512_set1_ps(two_over_pi)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
q = _mm512_cvtps_epi32(fq);
r = _mm512_sub_ps(ax, _mm512_mul_ps(fq, _mm512_set1_ps(pi_over_2_hi)));
r = _mm512_sub_ps(r, _mm512_mul_ps(fq, _mm512_set1_ps(pi_over_2_lo)));
r2 = _mm512_mul_ps(r, r);
}
constexpr void sincos_poly_f32x16(__m512 r, __m512 r2, __m512& cos_r, __m512& sin_r) {
cos_r = _mm512_fmadd_ps(_mm512_set1_ps(c10), r2, _mm512_set1_ps(c8));
cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c6));
cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c4));
cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c2));
cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c0));
sin_r = _mm512_fmadd_ps(_mm512_set1_ps(s9), r2, _mm512_set1_ps(s7));
sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s5));
sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s3));
sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s1));
sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(1.0f));
sin_r = _mm512_mul_ps(sin_r, r);
}
constexpr __m512 cos_f32x16(__m512 x) {
__m512 ax = _mm512_abs_ps(x);
__m512 r, r2; __m512i q;
range_reduce_f32x16(ax, r, r2, q);
__m512 cos_r, sin_r;
sincos_poly_f32x16(r, r2, cos_r, sin_r);
__mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1));
__m512 result = _mm512_mask_blend_ps(odd, cos_r, sin_r);
__m512i need_neg = _mm512_and_si512(
_mm512_add_epi32(q, _mm512_set1_epi32(1)), _mm512_set1_epi32(2));
__m512 neg_mask = _mm512_castsi512_ps(_mm512_slli_epi32(need_neg, 30));
return _mm512_xor_ps(result, neg_mask);
}
constexpr __m512 sin_f32x16(__m512 x) {
__m512 x_sign = _mm512_and_ps(x, _mm512_set1_ps(-0.0f));
__m512 ax = _mm512_abs_ps(x);
__m512 r, r2; __m512i q;
range_reduce_f32x16(ax, r, r2, q);
__m512 cos_r, sin_r;
sincos_poly_f32x16(r, r2, cos_r, sin_r);
__mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1));
__m512 result = _mm512_mask_blend_ps(odd, sin_r, cos_r);
__m512i need_neg = _mm512_and_si512(q, _mm512_set1_epi32(2));
__m512 neg_mask = _mm512_castsi512_ps(_mm512_slli_epi32(need_neg, 30));
result = _mm512_xor_ps(result, neg_mask);
return _mm512_xor_ps(result, x_sign);
}
// --- 256-bit sincos ---
constexpr void sincos_f32x8(__m256 x, __m256& out_sin, __m256& out_cos) {
const __m256 sign_mask = _mm256_set1_ps(-0.0f);
__m256 x_sign = _mm256_and_ps(x, sign_mask);
__m256 ax = _mm256_andnot_ps(sign_mask, x);
__m256 r, r2; __m256i q;
range_reduce_f32x8(ax, r, r2, q);
__m256 cos_r, sin_r;
sincos_poly_f32x8(r, r2, cos_r, sin_r);
__m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1));
__m256 is_odd = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1)));
// cos: swap on odd, negate if (q+1)&2
out_cos = _mm256_blendv_ps(cos_r, sin_r, is_odd);
__m256i cos_neg = _mm256_and_si256(
_mm256_add_epi32(q, _mm256_set1_epi32(1)), _mm256_set1_epi32(2));
out_cos = _mm256_xor_ps(out_cos,
_mm256_castsi256_ps(_mm256_slli_epi32(cos_neg, 30)));
// sin: swap on odd, negate if q&2, apply input sign
out_sin = _mm256_blendv_ps(sin_r, cos_r, is_odd);
__m256i sin_neg = _mm256_and_si256(q, _mm256_set1_epi32(2));
out_sin = _mm256_xor_ps(out_sin,
_mm256_castsi256_ps(_mm256_slli_epi32(sin_neg, 30)));
out_sin = _mm256_xor_ps(out_sin, x_sign);
}
// --- 512-bit sincos ---
constexpr void sincos_f32x16(__m512 x, __m512& out_sin, __m512& out_cos) {
__m512 x_sign = _mm512_and_ps(x, _mm512_set1_ps(-0.0f));
__m512 ax = _mm512_abs_ps(x);
__m512 r, r2; __m512i q;
range_reduce_f32x16(ax, r, r2, q);
__m512 cos_r, sin_r;
sincos_poly_f32x16(r, r2, cos_r, sin_r);
__mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1));
// cos
out_cos = _mm512_mask_blend_ps(odd, cos_r, sin_r);
__m512i cos_neg = _mm512_and_si512(
_mm512_add_epi32(q, _mm512_set1_epi32(1)), _mm512_set1_epi32(2));
out_cos = _mm512_xor_ps(out_cos,
_mm512_castsi512_ps(_mm512_slli_epi32(cos_neg, 30)));
// sin
out_sin = _mm512_mask_blend_ps(odd, sin_r, cos_r);
__m512i sin_neg = _mm512_and_si512(q, _mm512_set1_epi32(2));
out_sin = _mm512_xor_ps(out_sin,
_mm512_castsi512_ps(_mm512_slli_epi32(sin_neg, 30)));
out_sin = _mm512_xor_ps(out_sin, x_sign);
}
};
}

View file

@ -34,7 +34,7 @@
],
"tests":[
{
"name": "F16x86",
"name": "F16-x86-64-sapphirerapids",
"implementations": ["tests/VectorF16"],
"march": "sapphirerapids",
"extends": ["lib-shared"]

View file

@ -22,8 +22,9 @@ using namespace Crafter;
extern "C" {
std::string* RunTest() {
// Test 1: Load/Store functionality
{
_Float16 floats[] {0,1,2,3,4,5,6,7,8};
_Float16 floats[] {0,1,2,3,4,5,6,7};
VectorF16<8, 1> vec1(floats);
Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = vec1.Store();
@ -34,8 +35,9 @@ extern "C" {
}
}
// Test 2: Addition operator
{
_Float16 floats[] {0,1,2,3,4,5,6,7,8};
_Float16 floats[] {0,1,2,3,4,5,6,7};
VectorF16<8, 1> vec1(floats);
VectorF16<8, 1> result = vec1 + vec1;
Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = result.Store();
@ -45,8 +47,348 @@ extern "C" {
}
}
}
// Test 3: Subtraction operator
{
_Float16 floats[] {0,1,2,3,4,5,6,7};
VectorF16<8, 1> vec1(floats);
VectorF16<8, 1> result = vec1 - vec1;
Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = result.Store();
for(std::uint8_t i = 0; i < 8; i++) {
if(stored.v[i] != floats[i] - floats[i]) {
return new std::string("Subtract does not match");
}
}
}
// Test 4: Multiplication operator
{
_Float16 floats[] {1,2,3,4,5,6,7,8};
VectorF16<8, 1> vec1(floats);
VectorF16<8, 1> result = vec1 * vec1;
Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = result.Store();
for(std::uint8_t i = 0; i < 8; i++) {
if(stored.v[i] != floats[i] * floats[i]) {
return new std::string("Multiply does not match");
}
}
}
// Test 5: Division operator
{
_Float16 floats[] {2,4,6,8,10,12,14,16};
VectorF16<8, 1> vec1(floats);
VectorF16<8, 1> result = vec1 / vec1;
Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = result.Store();
for(std::uint8_t i = 0; i < 8; i++) {
if(stored.v[i] != floats[i] / floats[i]) {
return new std::string("Divide does not match");
}
}
}
// Test 6: Compound addition operator
{
_Float16 floats[] {1,2,3,4,5,6,7,8};
VectorF16<8, 1> vec1(floats);
VectorF16<8, 1> vec2(floats);
vec1 += vec2;
Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = vec1.Store();
for(std::uint8_t i = 0; i < 8; i++) {
if(stored.v[i] != floats[i] + floats[i]) {
return new std::string("Compound Add does not match");
}
}
}
// Test 7: Compound subtraction operator
{
_Float16 floats[] {1,2,3,4,5,6,7,8};
VectorF16<8, 1> vec1(floats);
VectorF16<8, 1> vec2(floats);
vec1 -= vec2;
Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = vec1.Store();
for(std::uint8_t i = 0; i < 8; i++) {
if(stored.v[i] != floats[i] - floats[i]) {
return new std::string("Compound Subtract does not match");
}
}
}
// Test 8: Compound multiplication operator
{
_Float16 floats[] {1,2,3,4,5,6,7,8};
VectorF16<8, 1> vec1(floats);
VectorF16<8, 1> vec2(floats);
vec1 *= vec2;
Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = vec1.Store();
for(std::uint8_t i = 0; i < 8; i++) {
if(stored.v[i] != floats[i] * floats[i]) {
return new std::string("Compound Multiply does not match");
}
}
}
// Test 9: Compound division operator
{
_Float16 floats[] {2,4,6,8,10,12,14,16};
VectorF16<8, 1> vec1(floats);
VectorF16<8, 1> vec2(floats);
vec1 /= vec2;
Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = vec1.Store();
for(std::uint8_t i = 0; i < 8; i++) {
if(stored.v[i] != floats[i] / floats[i]) {
return new std::string("Compound Divide does not match");
}
}
}
// Test 10: Scalar addition
{
_Float16 floats[] {1,2,3,4,5,6,7,8};
VectorF16<8, 1> vec1(floats);
VectorF16<8, 1> result = vec1 + _Float16(1.0);
Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = result.Store();
for(std::uint8_t i = 0; i < 8; i++) {
if(stored.v[i] != floats[i] + 1.0) {
return new std::string("Scalar Add does not match");
}
}
}
// Test 11: Scalar subtraction
{
_Float16 floats[] {1,2,3,4,5,6,7,8};
VectorF16<8, 1> vec1(floats);
VectorF16<8, 1> result = vec1 - _Float16(1.0);
Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = result.Store();
for(std::uint8_t i = 0; i < 8; i++) {
if(stored.v[i] != floats[i] - 1.0) {
return new std::string("Scalar Subtract does not match");
}
}
}
// Test 12: Scalar multiplication
{
_Float16 floats[] {1,2,3,4,5,6,7,8};
VectorF16<8, 1> vec1(floats);
VectorF16<8, 1> result = vec1 * _Float16(2.0);
Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = result.Store();
for(std::uint8_t i = 0; i < 8; i++) {
if(stored.v[i] != floats[i] * 2.0) {
return new std::string("Scalar Multiply does not match");
}
}
}
// Test 13: Scalar division
{
_Float16 floats[] {2,4,6,8,10,12,14,16};
VectorF16<8, 1> vec1(floats);
VectorF16<8, 1> result = vec1 / _Float16(2.0);
Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = result.Store();
for(std::uint8_t i = 0; i < 8; i++) {
if(stored.v[i] != floats[i] / 2.0) {
return new std::string("Scalar Divide does not match");
}
}
}
// Test 14: Compound scalar addition
{
_Float16 floats[] {1,2,3,4,5,6,7,8};
VectorF16<8, 1> vec1(floats);
vec1 += _Float16(1.0);
Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = vec1.Store();
for(std::uint8_t i = 0; i < 8; i++) {
if(stored.v[i] != floats[i] + 1.0) {
return new std::string("Compound Scalar Add does not match");
}
}
}
// Test 15: Compound scalar subtraction
{
_Float16 floats[] {1,2,3,4,5,6,7,8};
VectorF16<8, 1> vec1(floats);
vec1 -= _Float16(1.0);
Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = vec1.Store();
for(std::uint8_t i = 0; i < 8; i++) {
if(stored.v[i] != floats[i] - 1.0) {
return new std::string("Compound Scalar Subtract does not match");
}
}
}
// Test 16: Compound scalar multiplication
{
_Float16 floats[] {1,2,3,4,5,6,7,8};
VectorF16<8, 1> vec1(floats);
vec1 *= _Float16(2.0);
Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = vec1.Store();
for(std::uint8_t i = 0; i < 8; i++) {
if(stored.v[i] != floats[i] * 2.0) {
return new std::string("Compound Scalar Multiply does not match");
}
}
}
// Test 17: Compound scalar division
{
_Float16 floats[] {2,4,6,8,10,12,14,16};
VectorF16<8, 1> vec1(floats);
vec1 /= _Float16(2.0);
Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = vec1.Store();
for(std::uint8_t i = 0; i < 8; i++) {
if(stored.v[i] != floats[i] / 2.0) {
return new std::string("Compound Scalar Divide does not match");
}
}
}
// Test 18: Equality operator
{
_Float16 floats[] {1,2,3,4,5,6,7,8};
VectorF16<8, 1> vec1(floats);
VectorF16<8, 1> vec2(floats);
if (!(vec1 == vec2)) {
return new std::string("Equality operator does not match");
}
}
// Test 19: Inequality operator
{
_Float16 floats1[] {1,2,3,4,5,6,7,8};
_Float16 floats2[] {2,3,4,5,6,7,8,9};
VectorF16<8, 1> vec1(floats1);
VectorF16<8, 1> vec2(floats2);
if (!(vec1 != vec2)) {
return new std::string("Inequality operator does not match");
}
}
// Test 20: Negation operator
{
_Float16 floats[] {1,2,3,4,5,6,7,8};
VectorF16<8, 1> vec1(floats);
VectorF16<8, 1> result = -vec1;
Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = result.Store();
for(std::uint8_t i = 0; i < 8; i++) {
if(stored.v[i] != -floats[i]) {
return new std::string("Negation operator does not match");
}
}
}
// Test 21: Length calculation
{
_Float16 floats[] {3,4,0,0,0,0,0,0};
VectorF16<8, 1> vec1(floats);
_Float16 length = vec1.Length();
_Float16 expectedLength = 5.0; // sqrt(3^2 + 4^2)
if (std::abs((float)length - (float)expectedLength) > 0.001) {
return new std::string("Length calculation does not match");
}
}
// Test 22: Length squared calculation
{
_Float16 floats[] {3,4,0,0,0,0,0,0};
VectorF16<8, 1> vec1(floats);
_Float16 lengthSq = vec1.LengthSq();
_Float16 expectedLengthSq = 25.0; // 3^2 + 4^2
if (std::abs((float)lengthSq - (float)expectedLengthSq) > 0.001) {
return new std::string("Length squared calculation does not match");
}
}
// Test 25: Shuffle operation
{
_Float16 floats[] {1,2,3,4,5,6,7,8};
VectorF16<8, 1> vec1(floats);
// Shuffle indices 0,1,2,3 -> 3,2,1,0 (reverse first 4 elements)
VectorF16<8, 1> result = vec1.template Shuffle<{{3,2,1,0,7,6,5,4}}>();
Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = result.Store();
if (stored.v[0] != 4 || stored.v[1] != 3 || stored.v[2] != 2 || stored.v[3] != 1) {
return new std::string("Shuffle operation does not match");
}
}
// Test 26: Dot product
{
_Float16 floats1[] {1,2,3,4,0,0,0,0};
_Float16 floats2[] {2,3,4,5,0,0,0,0};
VectorF16<8, 1> vec1(floats1);
VectorF16<8, 1> vec2(floats2);
_Float16 dot = VectorF16<8, 1>::Dot(vec1, vec2);
_Float16 expectedDot = 1*2 + 2*3 + 3*4 + 4*5; // 2 + 6 + 12 + 20 = 40
if (std::abs((float)dot - (float)expectedDot) > 0.001) {
return new std::string("Dot product does not match");
}
}
// Test 27: Cross product (for 3D vectors)
{
_Float16 floats1[] {1,2,3,0,0,0,0,0};
_Float16 floats2[] {4,5,6,0,0,0,0,0};
VectorF16<3, 1> vec1(floats1);
VectorF16<3, 1> vec2(floats2);
VectorF16<3, 1> result = VectorF16<3, 1>::Cross(vec1, vec2);
Vector<_Float16, 3, VectorF16<3, 1>::Alignment> stored = result.Store();
// Cross product: (1,2,3) x (4,5,6) = (2*6-3*5, 3*4-1*6, 1*5-2*4) = (-3, 6, -3)
if (stored.v[0] != -3 || stored.v[1] != 6 || stored.v[2] != -3) {
return new std::string("Cross product does not match");
}
}
// Test 28: Multiply-Add operation
{
_Float16 floats1[] {1,2,3,4,0,0,0,0};
_Float16 floats2[] {2,3,4,5,0,0,0,0};
_Float16 floats3[] {1,1,1,1,0,0,0,0};
VectorF16<8, 1> vec1(floats1);
VectorF16<8, 1> vec2(floats2);
VectorF16<8, 1> vec3(floats3);
VectorF16<8, 1> result = VectorF16<8, 1>::MulitplyAdd(vec1, vec2, vec3);
Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = result.Store();
// Should compute (1*2 + 1, 2*3 + 1, 3*4 + 1, 4*5 + 1, ...) = (3, 7, 13, 21, ...)
for(std::uint8_t i = 0; i < 4; i++) {
if(stored.v[i] != floats1[i]*floats2[i] + floats3[i]) {
return new std::string("Multiply-Add operation does not match");
}
}
}
// Test 29: Multiply-Subtract operation
{
_Float16 floats1[] {1,2,3,4,0,0,0,0};
_Float16 floats2[] {2,3,4,5,0,0,0,0};
_Float16 floats3[] {1,1,1,1,0,0,0,0};
VectorF16<8, 1> vec1(floats1);
VectorF16<8, 1> vec2(floats2);
VectorF16<8, 1> vec3(floats3);
VectorF16<8, 1> result = VectorF16<8, 1>::MulitplySub(vec1, vec2, vec3);
Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = result.Store();
// Should compute (1*2 - 1, 2*3 - 1, 3*4 - 1, 4*5 - 1, ...) = (1, 5, 11, 19, ...)
for(std::uint8_t i = 0; i < 4; i++) {
if(stored.v[i] != floats1[i]*floats2[i] - floats3[i]) {
return new std::string("Multiply-Subtract operation does not match");
}
}
}
// Test 30: Constructor with single value
{
VectorF16<8, 1> vec1(_Float16(5.0));
Vector<_Float16, 8, VectorF16<8, 1>::Alignment> stored = vec1.Store();
for(std::uint8_t i = 0; i < 8; i++) {
if(stored.v[i] != 5.0) {
return new std::string("Single value constructor does not match");
}
}
}
return nullptr;
}
}