@@ -125,16 +125,14 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
125
125
EBM_ASSERT (nullptr != apBins);
126
126
EBM_ASSERT (nullptr != apBinsEnd);
127
127
EBM_ASSERT (apBins < apBinsEnd); // if zero bins then we should have handled elsewhere
128
- EBM_ASSERT (1 <= cSlices);
128
+ EBM_ASSERT (2 <= cSlices); // if there were no cuts then we wouldn't satisfy the minimum gain
129
129
EBM_ASSERT (2 <= cBins);
130
130
EBM_ASSERT (cSlices <= cBins);
131
131
EBM_ASSERT (!bNominal || cSlices == cBins);
132
132
133
133
ErrorEbm error;
134
134
135
135
#ifndef NDEBUG
136
- auto * const pRootTreeNodeDebug = pBoosterShell->GetTreeNodesTemp <bHessian>();
137
- size_t cSamplesExpectedDebug = static_cast <size_t >(pRootTreeNodeDebug->GetBin ()->GetCountSamples ());
138
136
size_t cSamplesTotalDebug = 0 ;
139
137
#endif // NDEBUG
140
138
@@ -165,16 +163,10 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
165
163
166
164
const Bin<FloatMain, UIntMain, true , true , bHessian>* pMissingBin = nullptr ;
167
165
const Bin<FloatMain, UIntMain, true , true , bHessian>* pDregsBin = nullptr ;
168
- const Bin<FloatMain, UIntMain, true , true , bHessian>* const * ppBinCur = nullptr ;
169
- if (bNominal) {
170
- UIntSplit iSplit = 1 ;
171
- while (cSlices != iSplit) {
172
- pSplit[iSplit - 1 ] = iSplit;
173
- ++iSplit;
174
- }
175
- ppBinCur = reinterpret_cast <const Bin<FloatMain, UIntMain, true , true , bHessian>* const *>(apBins);
176
- } else {
166
+ const Bin<FloatMain, UIntMain, true , true , bHessian>* const * ppBinCur;
167
+ if (!bNominal) {
177
168
pUpdateScore = aUpdateScore;
169
+ ppBinCur = nullptr ;
178
170
179
171
if (bMissing) {
180
172
EBM_ASSERT (2 <= cSlices); // no cuts if there was only missing bin
@@ -187,9 +179,16 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
187
179
*pSplit = 1 ;
188
180
++pSplit;
189
181
190
- // pUpdateScore is overwritten later if bNominal
191
182
pUpdateScore += cScores;
192
183
}
184
+ } else {
185
+ UIntSplit iSplit = 1 ;
186
+ EBM_ASSERT (2 <= cSlices); // no cuts if there was only missing bin
187
+ do {
188
+ pSplit[iSplit - 1 ] = iSplit;
189
+ ++iSplit;
190
+ } while (cSlices != iSplit);
191
+ ppBinCur = reinterpret_cast <const Bin<FloatMain, UIntMain, true , true , bHessian>* const *>(apBins);
193
192
}
194
193
195
194
const size_t cBytesPerBin = GetBinSize<FloatMain, UIntMain>(true , true , bHessian, cScores);
@@ -256,12 +255,7 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
256
255
}
257
256
258
257
EBM_ASSERT (apBins <= ppBinLast);
259
- EBM_ASSERT (ppBinLast < apBins +
260
- (cBins -
261
- (nullptr != pMissingValueTreeNode ||
262
- bMissing && !bNominal && (TermBoostFlags_MissingSeparate & flags) ?
263
- size_t {1 } :
264
- size_t {0 })));
258
+ EBM_ASSERT (ppBinLast < apBinsEnd);
265
259
266
260
#ifndef NDEBUG
267
261
cSamplesTotalDebug += static_cast <size_t >(pTreeNode->GetBin ()->GetCountSamples ());
@@ -291,6 +285,7 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
291
285
} else {
292
286
++iEdge; // missing is at index 0 in the model, so we are offset by one
293
287
if (TermBoostFlags_MissingHigh & flags) {
288
+ // This node might not hold the missing bin, but since we reassign on every pass, the final assignment will be the missing bin.
294
289
pMissingBin = pTreeNode->GetBin ();
295
290
EBM_ASSERT (iEdge <= cBins + 1 );
296
291
if (bDone) {
@@ -386,7 +381,9 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
386
381
done:;
387
382
388
383
#ifndef NDEBUG
389
- EBM_ASSERT (cSamplesTotalDebug == cSamplesExpectedDebug);
384
+ const size_t cSamplesDebug =
385
+ static_cast <size_t >(pBoosterShell->GetTreeNodesTemp <bHessian>()->GetBin ()->GetCountSamples ());
386
+ EBM_ASSERT (cSamplesTotalDebug == cSamplesDebug);
390
387
EBM_ASSERT (bNominal || pUpdateScore == aUpdateScore + cScores * cSlices);
391
388
EBM_ASSERT (bNominal || pSplit == cSlices - 1 + pInnerTermUpdate->GetSplitPointer (iDimension));
392
389
@@ -401,14 +398,15 @@ done:;
401
398
402
399
if (nullptr != pDregsBin) {
403
400
EBM_ASSERT (bNominal);
401
+ EBM_ASSERT (nullptr != pDregsTreeNode);
404
402
405
403
std::sort (apBins, apBinsEnd);
406
404
407
405
const auto * const * ppBinSweep =
408
406
reinterpret_cast <const Bin<FloatMain, UIntMain, true , true , bHessian>* const *>(apBins);
409
407
410
- // for nominal, we would only skip the 0th missing bin if missing was assigned based on gain
411
- size_t iBin = nullptr == pMissingValueTreeNode ? size_t {0 } : size_t { 1 };
408
+ EBM_ASSERT ( nullptr == pMissingValueTreeNode);
409
+ size_t iBin = size_t {0 };
412
410
do {
413
411
const auto * const pBin = IndexBin (aBins, cBytesPerBin * iBin);
414
412
if (apBinsEnd != ppBinSweep && *ppBinSweep == pBin) {
@@ -436,6 +434,8 @@ done:;
436
434
}
437
435
++iBin;
438
436
} while (cBins != iBin);
437
+ } else {
438
+ EBM_ASSERT (nullptr == pDregsTreeNode);
439
439
}
440
440
441
441
if (nullptr != pMissingBin) {
@@ -459,6 +459,7 @@ done:;
459
459
} while (pGradientPairEnd != pGradientPair);
460
460
} else {
461
461
EBM_ASSERT (!bMissing || bNominal);
462
+ EBM_ASSERT (nullptr == pMissingValueTreeNode);
462
463
}
463
464
464
465
LOG_0 (Trace_Verbose, " Exited Flatten" );
0 commit comments