Skip to content

Commit 15a600d

Browse files
committed
Minor changes to the single-dimensional tree-building Flatten code, mostly to improve debugging and asserts
1 parent ef742fd commit 15a600d

File tree

1 file changed

+23
-22
lines changed

1 file changed

+23
-22
lines changed

shared/libebm/PartitionOneDimensionalBoosting.cpp

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -125,16 +125,14 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
125125
EBM_ASSERT(nullptr != apBins);
126126
EBM_ASSERT(nullptr != apBinsEnd);
127127
EBM_ASSERT(apBins < apBinsEnd); // if zero bins then we should have handled elsewhere
128-
EBM_ASSERT(1 <= cSlices);
128+
EBM_ASSERT(2 <= cSlices); // if there were no cuts then we wouldn't satisfy the minimum gain
129129
EBM_ASSERT(2 <= cBins);
130130
EBM_ASSERT(cSlices <= cBins);
131131
EBM_ASSERT(!bNominal || cSlices == cBins);
132132

133133
ErrorEbm error;
134134

135135
#ifndef NDEBUG
136-
auto* const pRootTreeNodeDebug = pBoosterShell->GetTreeNodesTemp<bHessian>();
137-
size_t cSamplesExpectedDebug = static_cast<size_t>(pRootTreeNodeDebug->GetBin()->GetCountSamples());
138136
size_t cSamplesTotalDebug = 0;
139137
#endif // NDEBUG
140138

@@ -165,16 +163,10 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
165163

166164
const Bin<FloatMain, UIntMain, true, true, bHessian>* pMissingBin = nullptr;
167165
const Bin<FloatMain, UIntMain, true, true, bHessian>* pDregsBin = nullptr;
168-
const Bin<FloatMain, UIntMain, true, true, bHessian>* const* ppBinCur = nullptr;
169-
if(bNominal) {
170-
UIntSplit iSplit = 1;
171-
while(cSlices != iSplit) {
172-
pSplit[iSplit - 1] = iSplit;
173-
++iSplit;
174-
}
175-
ppBinCur = reinterpret_cast<const Bin<FloatMain, UIntMain, true, true, bHessian>* const*>(apBins);
176-
} else {
166+
const Bin<FloatMain, UIntMain, true, true, bHessian>* const* ppBinCur;
167+
if(!bNominal) {
177168
pUpdateScore = aUpdateScore;
169+
ppBinCur = nullptr;
178170

179171
if(bMissing) {
180172
EBM_ASSERT(2 <= cSlices); // no cuts if there was only missing bin
@@ -187,9 +179,16 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
187179
*pSplit = 1;
188180
++pSplit;
189181

190-
// pUpdateScore is overwritten later if bNominal
191182
pUpdateScore += cScores;
192183
}
184+
} else {
185+
UIntSplit iSplit = 1;
186+
EBM_ASSERT(2 <= cSlices); // no cuts if there was only missing bin
187+
do {
188+
pSplit[iSplit - 1] = iSplit;
189+
++iSplit;
190+
} while(cSlices != iSplit);
191+
ppBinCur = reinterpret_cast<const Bin<FloatMain, UIntMain, true, true, bHessian>* const*>(apBins);
193192
}
194193

195194
const size_t cBytesPerBin = GetBinSize<FloatMain, UIntMain>(true, true, bHessian, cScores);
@@ -256,12 +255,7 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
256255
}
257256

258257
EBM_ASSERT(apBins <= ppBinLast);
259-
EBM_ASSERT(ppBinLast < apBins +
260-
(cBins -
261-
(nullptr != pMissingValueTreeNode ||
262-
bMissing && !bNominal && (TermBoostFlags_MissingSeparate & flags) ?
263-
size_t{1} :
264-
size_t{0})));
258+
EBM_ASSERT(ppBinLast < apBinsEnd);
265259

266260
#ifndef NDEBUG
267261
cSamplesTotalDebug += static_cast<size_t>(pTreeNode->GetBin()->GetCountSamples());
@@ -291,6 +285,7 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
291285
} else {
292286
++iEdge; // missing is at index 0 in the model, so we are offset by one
293287
if(TermBoostFlags_MissingHigh & flags) {
288+
// This might not be the missing bin, but if we keep assigning it, the last time will be missing.
294289
pMissingBin = pTreeNode->GetBin();
295290
EBM_ASSERT(iEdge <= cBins + 1);
296291
if(bDone) {
@@ -386,7 +381,9 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
386381
done:;
387382

388383
#ifndef NDEBUG
389-
EBM_ASSERT(cSamplesTotalDebug == cSamplesExpectedDebug);
384+
const size_t cSamplesDebug =
385+
static_cast<size_t>(pBoosterShell->GetTreeNodesTemp<bHessian>()->GetBin()->GetCountSamples());
386+
EBM_ASSERT(cSamplesTotalDebug == cSamplesDebug);
390387
EBM_ASSERT(bNominal || pUpdateScore == aUpdateScore + cScores * cSlices);
391388
EBM_ASSERT(bNominal || pSplit == cSlices - 1 + pInnerTermUpdate->GetSplitPointer(iDimension));
392389

@@ -401,14 +398,15 @@ done:;
401398

402399
if(nullptr != pDregsBin) {
403400
EBM_ASSERT(bNominal);
401+
EBM_ASSERT(nullptr != pDregsTreeNode);
404402

405403
std::sort(apBins, apBinsEnd);
406404

407405
const auto* const* ppBinSweep =
408406
reinterpret_cast<const Bin<FloatMain, UIntMain, true, true, bHessian>* const*>(apBins);
409407

410-
// for nominal, we would only skip the 0th missing bin if missing was assigned based on gain
411-
size_t iBin = nullptr == pMissingValueTreeNode ? size_t{0} : size_t{1};
408+
EBM_ASSERT(nullptr == pMissingValueTreeNode);
409+
size_t iBin = size_t{0};
412410
do {
413411
const auto* const pBin = IndexBin(aBins, cBytesPerBin * iBin);
414412
if(apBinsEnd != ppBinSweep && *ppBinSweep == pBin) {
@@ -436,6 +434,8 @@ done:;
436434
}
437435
++iBin;
438436
} while(cBins != iBin);
437+
} else {
438+
EBM_ASSERT(nullptr == pDregsTreeNode);
439439
}
440440

441441
if(nullptr != pMissingBin) {
@@ -459,6 +459,7 @@ done:;
459459
} while(pGradientPairEnd != pGradientPair);
460460
} else {
461461
EBM_ASSERT(!bMissing || bNominal);
462+
EBM_ASSERT(nullptr == pMissingValueTreeNode);
462463
}
463464

464465
LOG_0(Trace_Verbose, "Exited Flatten");

0 commit comments

Comments
 (0)