Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 48 additions & 38 deletions ops/dot-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,12 @@ HWY_MAYBE_UNUSED double ConditionNumber(const WT* HWY_RESTRICT w,
Decompress2(df, packed_v, i, v0, v1);
const VF mul0 = hn::Mul(w0, v0);
const VF mul1 = hn::Mul(w1, v1);
UpdateCascadedSums(df, mul0, sum, sum_err);
UpdateCascadedSums(df, mul1, sum, sum_err);
UpdateCascadedSums(df, hn::Abs(mul0), sum_abs, sum_abs_err);
UpdateCascadedSums(df, hn::Abs(mul1), sum_abs, sum_abs_err);
gcpp::HWY_NAMESPACE::UpdateCascadedSums(df, mul0, sum, sum_err);
gcpp::HWY_NAMESPACE::UpdateCascadedSums(df, mul1, sum, sum_err);
gcpp::HWY_NAMESPACE::UpdateCascadedSums(df, hn::Abs(mul0), sum_abs,
sum_abs_err);
gcpp::HWY_NAMESPACE::UpdateCascadedSums(df, hn::Abs(mul1), sum_abs,
sum_abs_err);
}
}

Expand All @@ -91,15 +93,18 @@ HWY_MAYBE_UNUSED double ConditionNumber(const WT* HWY_RESTRICT w,
const VF w0 = hn::Load(df, padded_w + padded_pos);
const VF v0 = hn::Load(df, padded_v + padded_pos);
const VF mul = hn::Mul(w0, v0);
UpdateCascadedSums(df, mul, sum, sum_err);
UpdateCascadedSums(df, hn::Abs(mul), sum_abs, sum_abs_err);
gcpp::HWY_NAMESPACE::UpdateCascadedSums(df, mul, sum, sum_err);
gcpp::HWY_NAMESPACE::UpdateCascadedSums(df, hn::Abs(mul), sum_abs,
sum_abs_err);
}
}

const float div = hwy::ScalarAbs(ReduceCascadedSums(df, sum, sum_err));
const float div =
hwy::ScalarAbs(gcpp::HWY_NAMESPACE::ReduceCascadedSums(df, sum, sum_err));
if (div == 0.0f) return hn::GetLane(hn::Inf(df));
const double cond = 2.0 * ReduceCascadedSums(df, sum_abs, sum_abs_err) /
static_cast<double>(div);
const double cond =
2.0 * gcpp::HWY_NAMESPACE::ReduceCascadedSums(df, sum_abs, sum_abs_err) /
static_cast<double>(div);
HWY_ASSERT(cond >= 0.0);
return cond;
}
Expand All @@ -124,10 +129,12 @@ HWY_MAYBE_UNUSED double ConditionNumber(const VT* HWY_RESTRICT v, size_t num) {
for (; i <= num - 2 * N; i += 2 * N) {
VF v0, v1;
Decompress2(df, packed_v, i, v0, v1);
UpdateCascadedSums(df, v0, sum, sum_err);
UpdateCascadedSums(df, v1, sum, sum_err);
UpdateCascadedSums(df, hn::Abs(v0), sum_abs, sum_abs_err);
UpdateCascadedSums(df, hn::Abs(v1), sum_abs, sum_abs_err);
gcpp::HWY_NAMESPACE::UpdateCascadedSums(df, v0, sum, sum_err);
gcpp::HWY_NAMESPACE::UpdateCascadedSums(df, v1, sum, sum_err);
gcpp::HWY_NAMESPACE::UpdateCascadedSums(df, hn::Abs(v0), sum_abs,
sum_abs_err);
gcpp::HWY_NAMESPACE::UpdateCascadedSums(df, hn::Abs(v1), sum_abs,
sum_abs_err);
}
}

Expand All @@ -140,15 +147,18 @@ HWY_MAYBE_UNUSED double ConditionNumber(const VT* HWY_RESTRICT v, size_t num) {
// 1..2 whole vectors, possibly zero-padded.
for (size_t padded_pos = 0; padded_pos < remaining; padded_pos += N) {
const VF v0 = hn::Load(df, padded_v + padded_pos);
UpdateCascadedSums(df, v0, sum, sum_err);
UpdateCascadedSums(df, hn::Abs(v0), sum_abs, sum_abs_err);
gcpp::HWY_NAMESPACE::UpdateCascadedSums(df, v0, sum, sum_err);
gcpp::HWY_NAMESPACE::UpdateCascadedSums(df, hn::Abs(v0), sum_abs,
sum_abs_err);
}
}

const float div = hwy::ScalarAbs(ReduceCascadedSums(df, sum, sum_err));
const float div =
hwy::ScalarAbs(gcpp::HWY_NAMESPACE::ReduceCascadedSums(df, sum, sum_err));
if (div == 0.0f) return hn::GetLane(hn::Inf(df));
const double cond = 2.0 * ReduceCascadedSums(df, sum_abs, sum_abs_err) /
static_cast<double>(div);
const double cond =
2.0 * gcpp::HWY_NAMESPACE::ReduceCascadedSums(df, sum_abs, sum_abs_err) /
static_cast<double>(div);
HWY_ASSERT(cond >= 0.0);
return cond;
}
Expand Down Expand Up @@ -326,16 +336,16 @@ struct DotKernelCompensated {
const VF v3, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
VF perr0, perr1, perr2, perr3;
const VF prod0 = TwoProducts(df, w0, v0, perr0);
const VF prod1 = TwoProducts(df, w1, v1, perr1);
const VF prod2 = TwoProducts(df, w2, v2, perr2);
const VF prod3 = TwoProducts(df, w3, v3, perr3);
const VF prod0 = gcpp::HWY_NAMESPACE::TwoProducts(df, w0, v0, perr0);
const VF prod1 = gcpp::HWY_NAMESPACE::TwoProducts(df, w1, v1, perr1);
const VF prod2 = gcpp::HWY_NAMESPACE::TwoProducts(df, w2, v2, perr2);
const VF prod3 = gcpp::HWY_NAMESPACE::TwoProducts(df, w3, v3, perr3);

VF serr0, serr1, serr2, serr3;
sum0 = TwoSums(df, prod0, sum0, serr0);
sum1 = TwoSums(df, prod1, sum1, serr1);
sum2 = TwoSums(df, prod2, sum2, serr2);
sum3 = TwoSums(df, prod3, sum3, serr3);
sum0 = gcpp::HWY_NAMESPACE::TwoSums(df, prod0, sum0, serr0);
sum1 = gcpp::HWY_NAMESPACE::TwoSums(df, prod1, sum1, serr1);
sum2 = gcpp::HWY_NAMESPACE::TwoSums(df, prod2, sum2, serr2);
sum3 = gcpp::HWY_NAMESPACE::TwoSums(df, prod3, sum3, serr3);

comp0 = hn::Add(comp0, hn::Add(perr0, serr0));
comp1 = hn::Add(comp1, hn::Add(perr1, serr1));
Expand All @@ -357,10 +367,10 @@ struct DotKernelCompensated {
const VS prod0 = hn::WidenMulPairwiseAdd(df, w0, v0);

VS serr0, serr1, serr2, serr3;
sum0 = TwoSums(df, prod0, sum0, serr0);
sum1 = TwoSums(df, prod1, sum1, serr1);
sum2 = TwoSums(df, prod2, sum2, serr2);
sum3 = TwoSums(df, prod3, sum3, serr3);
sum0 = gcpp::HWY_NAMESPACE::TwoSums(df, prod0, sum0, serr0);
sum1 = gcpp::HWY_NAMESPACE::TwoSums(df, prod1, sum1, serr1);
sum2 = gcpp::HWY_NAMESPACE::TwoSums(df, prod2, sum2, serr2);
sum3 = gcpp::HWY_NAMESPACE::TwoSums(df, prod3, sum3, serr3);

comp0 = hn::Add(comp0, serr0);
comp1 = hn::Add(comp1, serr1);
Expand All @@ -373,10 +383,10 @@ struct DotKernelCompensated {
HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0,
VF& comp0) const {
VF perr0;
const VF prod0 = TwoProducts(df, w0, v0, perr0);
const VF prod0 = gcpp::HWY_NAMESPACE::TwoProducts(df, w0, v0, perr0);

VF serr0;
sum0 = TwoSums(df, prod0, sum0, serr0);
sum0 = gcpp::HWY_NAMESPACE::TwoSums(df, prod0, sum0, serr0);

comp0 = hn::Add(comp0, hn::Add(perr0, serr0));
}
Expand All @@ -387,10 +397,10 @@ struct DotKernelCompensated {
HWY_INLINE void Update1(DRaw, const VR w0, const VR v0, VS& sum0,
VS& comp0) const {
const DS df;
const VS prod0 = WidenMulPairwiseAdd(df, w0, v0);
const VS prod0 = hn::WidenMulPairwiseAdd(df, w0, v0);

VS serr0;
sum0 = TwoSums(df, prod0, sum0, serr0);
sum0 = gcpp::HWY_NAMESPACE::TwoSums(df, prod0, sum0, serr0);

comp0 = hn::Add(comp0, serr0);
}
Expand All @@ -399,10 +409,10 @@ struct DotKernelCompensated {
HWY_INLINE float Reduce(DS df, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
VS& comp0, VS& comp1, VS& comp2, VS& comp3) const {
// Reduction tree: sum of all accumulators by pairs, then across lanes.
AssimilateCascadedSums(df, sum1, comp1, sum0, comp0);
AssimilateCascadedSums(df, sum3, comp3, sum2, comp2);
AssimilateCascadedSums(df, sum2, comp2, sum0, comp0);
return ReduceCascadedSums(df, sum0, comp0);
gcpp::HWY_NAMESPACE::AssimilateCascadedSums(df, sum1, comp1, sum0, comp0);
gcpp::HWY_NAMESPACE::AssimilateCascadedSums(df, sum3, comp3, sum2, comp2);
gcpp::HWY_NAMESPACE::AssimilateCascadedSums(df, sum2, comp2, sum0, comp0);
return gcpp::HWY_NAMESPACE::ReduceCascadedSums(df, sum0, comp0);
}
};

Expand Down
Loading