From 4136fa207faa2d3f3baa300d613057787b484aa6 Mon Sep 17 00:00:00 2001 From: Frank Tang Date: Fri, 30 Apr 2021 19:19:24 -0700 Subject: [PATCH] ICU-21569 Propagate LSTM memory allocation issue --- icu4c/source/common/brkeng.cpp | 4 +- icu4c/source/common/brkeng.h | 8 +- icu4c/source/common/dictbe.cpp | 41 ++-- icu4c/source/common/dictbe.h | 250 ++++++++++++----------- icu4c/source/common/lstmbe.cpp | 115 +++++++---- icu4c/source/common/lstmbe.h | 4 +- icu4c/source/common/rbbi_cache.cpp | 2 +- icu4c/source/test/intltest/lstmbetst.cpp | 6 +- 8 files changed, 243 insertions(+), 187 deletions(-) diff --git a/icu4c/source/common/brkeng.cpp b/icu4c/source/common/brkeng.cpp index a4c88a4db6b..dbdd7839d94 100644 --- a/icu4c/source/common/brkeng.cpp +++ b/icu4c/source/common/brkeng.cpp @@ -78,7 +78,9 @@ int32_t UnhandledEngine::findBreaks( UText *text, int32_t /* startPos */, int32_t endPos, - UVector32 &/*foundBreaks*/ ) const { + UVector32 &/*foundBreaks*/, + UErrorCode &status) const { + if (U_FAILURE(status)) return 0; UChar32 c = utext_current32(text); while((int32_t)utext_getNativeIndex(text) < endPos && fHandled->contains(c)) { utext_next32(text); // TODO: recast loop to work with post-increment operations. diff --git a/icu4c/source/common/brkeng.h b/icu4c/source/common/brkeng.h index 155433b89a8..f6b64c83e25 100644 --- a/icu4c/source/common/brkeng.h +++ b/icu4c/source/common/brkeng.h @@ -68,12 +68,14 @@ class LanguageBreakEngine : public UMemory { * @param startPos The start of the run within the supplied text. * @param endPos The end of the run within the supplied text. * @param foundBreaks A Vector of int32_t to receive the breaks. + * @param status Information on any errors encountered. * @return The number of breaks found. 
*/ virtual int32_t findBreaks( UText *text, int32_t startPos, int32_t endPos, - UVector32 &foundBreaks ) const = 0; + UVector32 &foundBreaks, + UErrorCode &status) const = 0; }; @@ -185,12 +187,14 @@ class UnhandledEngine : public LanguageBreakEngine { * @param startPos The start of the run within the supplied text. * @param endPos The end of the run within the supplied text. * @param foundBreaks An allocated C array of the breaks found, if any + * @param status Information on any errors encountered. * @return The number of breaks found. */ virtual int32_t findBreaks( UText *text, int32_t startPos, int32_t endPos, - UVector32 &foundBreaks ) const; + UVector32 &foundBreaks, + UErrorCode &status) const; /** *

Tell the engine to handle a particular character and break type.

diff --git a/icu4c/source/common/dictbe.cpp b/icu4c/source/common/dictbe.cpp index 44285755f3f..f9aae5bc9ed 100644 --- a/icu4c/source/common/dictbe.cpp +++ b/icu4c/source/common/dictbe.cpp @@ -47,7 +47,9 @@ int32_t DictionaryBreakEngine::findBreaks( UText *text, int32_t startPos, int32_t endPos, - UVector32 &foundBreaks ) const { + UVector32 &foundBreaks, + UErrorCode& status) const { + if (U_FAILURE(status)) return 0; (void)startPos; // TODO: remove this param? int32_t result = 0; @@ -66,7 +68,7 @@ DictionaryBreakEngine::findBreaks( UText *text, } rangeStart = start; rangeEnd = current; - result = divideUpDictionaryRange(text, rangeStart, rangeEnd, foundBreaks); + result = divideUpDictionaryRange(text, rangeStart, rangeEnd, foundBreaks, status); utext_setNativeIndex(text, current); return result; @@ -227,7 +229,9 @@ int32_t ThaiBreakEngine::divideUpDictionaryRange( UText *text, int32_t rangeStart, int32_t rangeEnd, - UVector32 &foundBreaks ) const { + UVector32 &foundBreaks, + UErrorCode& status) const { + if (U_FAILURE(status)) return 0; utext_setNativeIndex(text, rangeStart); utext_moveIndex32(text, THAI_MIN_WORD_SPAN); if (utext_getNativeIndex(text) >= rangeEnd) { @@ -240,7 +244,6 @@ ThaiBreakEngine::divideUpDictionaryRange( UText *text, int32_t cpWordLength = 0; // Word Length in Code Points. 
int32_t cuWordLength = 0; // Word length in code units (UText native indexing) int32_t current; - UErrorCode status = U_ZERO_ERROR; PossibleWord words[THAI_LOOKAHEAD]; utext_setNativeIndex(text, rangeStart); @@ -465,7 +468,9 @@ int32_t LaoBreakEngine::divideUpDictionaryRange( UText *text, int32_t rangeStart, int32_t rangeEnd, - UVector32 &foundBreaks ) const { + UVector32 &foundBreaks, + UErrorCode& status) const { + if (U_FAILURE(status)) return 0; if ((rangeEnd - rangeStart) < LAO_MIN_WORD_SPAN) { return 0; // Not enough characters for two words } @@ -474,11 +479,10 @@ LaoBreakEngine::divideUpDictionaryRange( UText *text, int32_t cpWordLength = 0; int32_t cuWordLength = 0; int32_t current; - UErrorCode status = U_ZERO_ERROR; PossibleWord words[LAO_LOOKAHEAD]; - + utext_setNativeIndex(text, rangeStart); - + while (U_SUCCESS(status) && (current = (int32_t)utext_getNativeIndex(text)) < rangeEnd) { cuWordLength = 0; cpWordLength = 0; @@ -657,7 +661,9 @@ int32_t BurmeseBreakEngine::divideUpDictionaryRange( UText *text, int32_t rangeStart, int32_t rangeEnd, - UVector32 &foundBreaks ) const { + UVector32 &foundBreaks, + UErrorCode& status ) const { + if (U_FAILURE(status)) return 0; if ((rangeEnd - rangeStart) < BURMESE_MIN_WORD_SPAN) { return 0; // Not enough characters for two words } @@ -666,11 +672,10 @@ BurmeseBreakEngine::divideUpDictionaryRange( UText *text, int32_t cpWordLength = 0; int32_t cuWordLength = 0; int32_t current; - UErrorCode status = U_ZERO_ERROR; PossibleWord words[BURMESE_LOOKAHEAD]; - + utext_setNativeIndex(text, rangeStart); - + while (U_SUCCESS(status) && (current = (int32_t)utext_getNativeIndex(text)) < rangeEnd) { cuWordLength = 0; cpWordLength = 0; @@ -861,7 +866,9 @@ int32_t KhmerBreakEngine::divideUpDictionaryRange( UText *text, int32_t rangeStart, int32_t rangeEnd, - UVector32 &foundBreaks ) const { + UVector32 &foundBreaks, + UErrorCode& status ) const { + if (U_FAILURE(status)) return 0; if ((rangeEnd - rangeStart) < 
KHMER_MIN_WORD_SPAN) { return 0; // Not enough characters for two words } @@ -870,7 +877,6 @@ KhmerBreakEngine::divideUpDictionaryRange( UText *text, int32_t cpWordLength = 0; int32_t cuWordLength = 0; int32_t current; - UErrorCode status = U_ZERO_ERROR; PossibleWord words[KHMER_LOOKAHEAD]; utext_setNativeIndex(text, rangeStart); @@ -1110,7 +1116,9 @@ int32_t CjkBreakEngine::divideUpDictionaryRange( UText *inText, int32_t rangeStart, int32_t rangeEnd, - UVector32 &foundBreaks ) const { + UVector32 &foundBreaks, + UErrorCode& status) const { + if (U_FAILURE(status)) return 0; if (rangeStart >= rangeEnd) { return 0; } @@ -1122,9 +1130,6 @@ CjkBreakEngine::divideUpDictionaryRange( UText *inText, // If NULL then mapping is 1:1 LocalPointer inputMap; - UErrorCode status = U_ZERO_ERROR; - - // if UText has the input string as one contiguous UTF-16 chunk if ((inText->providerProperties & utext_i32_flag(UTEXT_PROVIDER_STABLE_CHUNKS)) && inText->chunkNativeStart <= rangeStart && diff --git a/icu4c/source/common/dictbe.h b/icu4c/source/common/dictbe.h index 4ea676fc716..4adaaa4f09d 100644 --- a/icu4c/source/common/dictbe.h +++ b/icu4c/source/common/dictbe.h @@ -68,17 +68,19 @@ class DictionaryBreakEngine : public LanguageBreakEngine { *

Find any breaks within a run in the supplied text.

* * @param text A UText representing the text. The iterator is left at - * the end of the run of characters which the engine is capable of handling + * the end of the run of characters which the engine is capable of handling * that starts from the first character in the range. * @param startPos The start of the run within the supplied text. * @param endPos The end of the run within the supplied text. * @param foundBreaks vector of int32_t to receive the break positions + * @param status Information on any errors encountered. * @return The number of breaks found. */ virtual int32_t findBreaks( UText *text, int32_t startPos, int32_t endPos, - UVector32 &foundBreaks ) const; + UVector32 &foundBreaks, + UErrorCode& status ) const; protected: @@ -96,12 +98,14 @@ class DictionaryBreakEngine : public LanguageBreakEngine { * @param rangeStart The start of the range of dictionary characters * @param rangeEnd The end of the range of dictionary characters * @param foundBreaks Output of C array of int32_t break positions, or 0 + * @param status Information on any errors encountered. * @return The number of breaks found */ virtual int32_t divideUpDictionaryRange( UText *text, int32_t rangeStart, int32_t rangeEnd, - UVector32 &foundBreaks ) const = 0; + UVector32 &foundBreaks, + UErrorCode& status) const = 0; }; @@ -153,12 +157,14 @@ class ThaiBreakEngine : public DictionaryBreakEngine { * @param rangeStart The start of the range of dictionary characters * @param rangeEnd The end of the range of dictionary characters * @param foundBreaks Output of C array of int32_t break positions, or 0 + * @param status Information on any errors encountered. 
* @return The number of breaks found */ virtual int32_t divideUpDictionaryRange( UText *text, int32_t rangeStart, int32_t rangeEnd, - UVector32 &foundBreaks ) const; + UVector32 &foundBreaks, + UErrorCode& status) const; }; @@ -209,127 +215,133 @@ class LaoBreakEngine : public DictionaryBreakEngine { * @param rangeStart The start of the range of dictionary characters * @param rangeEnd The end of the range of dictionary characters * @param foundBreaks Output of C array of int32_t break positions, or 0 + * @param status Information on any errors encountered. * @return The number of breaks found */ virtual int32_t divideUpDictionaryRange( UText *text, int32_t rangeStart, int32_t rangeEnd, - UVector32 &foundBreaks ) const; + UVector32 &foundBreaks, + UErrorCode& status) const; + +}; + +/******************************************************************* + * BurmeseBreakEngine + */ + +/** + *

BurmeseBreakEngine is a kind of DictionaryBreakEngine that uses a + * DictionaryMatcher and heuristics to determine Burmese-specific breaks.

+ * + *

After it is constructed a BurmeseBreakEngine may be shared between + * threads without synchronization.

+ */ +class BurmeseBreakEngine : public DictionaryBreakEngine { + private: + /** + * The set of characters handled by this engine + * @internal + */ + + UnicodeSet fBurmeseWordSet; + UnicodeSet fEndWordSet; + UnicodeSet fBeginWordSet; + UnicodeSet fMarkSet; + DictionaryMatcher *fDictionary; + + public: + + /** + *

Constructs a BurmeseBreakEngine that adopts the supplied dictionary matcher.

+ * + * @param adoptDictionary A DictionaryMatcher to adopt. Deleted when the + * engine is deleted. + */ + BurmeseBreakEngine(DictionaryMatcher *adoptDictionary, UErrorCode &status); + + /** + *

Virtual destructor.

+ */ + virtual ~BurmeseBreakEngine(); + + protected: + /** + *

Divide up a range of known dictionary characters.

+ * + * @param text A UText representing the text + * @param rangeStart The start of the range of dictionary characters + * @param rangeEnd The end of the range of dictionary characters + * @param foundBreaks Output of C array of int32_t break positions, or 0 + * @param status Information on any errors encountered. + * @return The number of breaks found + */ + virtual int32_t divideUpDictionaryRange( UText *text, + int32_t rangeStart, + int32_t rangeEnd, + UVector32 &foundBreaks, + UErrorCode& status) const; + +}; + +/******************************************************************* + * KhmerBreakEngine + */ + +/** + *

KhmerBreakEngine is a kind of DictionaryBreakEngine that uses a + * DictionaryMatcher and heuristics to determine Khmer-specific breaks.

+ * + *

After it is constructed a KhmerBreakEngine may be shared between + * threads without synchronization.

+ */ +class KhmerBreakEngine : public DictionaryBreakEngine { + private: + /** + * The set of characters handled by this engine + * @internal + */ + + UnicodeSet fKhmerWordSet; + UnicodeSet fEndWordSet; + UnicodeSet fBeginWordSet; + UnicodeSet fMarkSet; + DictionaryMatcher *fDictionary; + + public: + + /** + *

Constructs a KhmerBreakEngine that adopts the supplied dictionary matcher.

+ * + * @param adoptDictionary A DictionaryMatcher to adopt. Deleted when the + * engine is deleted. + */ + KhmerBreakEngine(DictionaryMatcher *adoptDictionary, UErrorCode &status); + + /** + *

Virtual destructor.

+ */ + virtual ~KhmerBreakEngine(); + + protected: + /** + *

Divide up a range of known dictionary characters.

+ * + * @param text A UText representing the text + * @param rangeStart The start of the range of dictionary characters + * @param rangeEnd The end of the range of dictionary characters + * @param foundBreaks Output of C array of int32_t break positions, or 0 + * @param status Information on any errors encountered. + * @return The number of breaks found + */ + virtual int32_t divideUpDictionaryRange( UText *text, + int32_t rangeStart, + int32_t rangeEnd, + UVector32 &foundBreaks, + UErrorCode& status) const; }; -/******************************************************************* - * BurmeseBreakEngine - */ - -/** - *

BurmeseBreakEngine is a kind of DictionaryBreakEngine that uses a - * DictionaryMatcher and heuristics to determine Burmese-specific breaks.

- * - *

After it is constructed a BurmeseBreakEngine may be shared between - * threads without synchronization.

- */ -class BurmeseBreakEngine : public DictionaryBreakEngine { - private: - /** - * The set of characters handled by this engine - * @internal - */ - - UnicodeSet fBurmeseWordSet; - UnicodeSet fEndWordSet; - UnicodeSet fBeginWordSet; - UnicodeSet fMarkSet; - DictionaryMatcher *fDictionary; - - public: - - /** - *

Default constructor.

- * - * @param adoptDictionary A DictionaryMatcher to adopt. Deleted when the - * engine is deleted. - */ - BurmeseBreakEngine(DictionaryMatcher *adoptDictionary, UErrorCode &status); - - /** - *

Virtual destructor.

- */ - virtual ~BurmeseBreakEngine(); - - protected: - /** - *

Divide up a range of known dictionary characters.

- * - * @param text A UText representing the text - * @param rangeStart The start of the range of dictionary characters - * @param rangeEnd The end of the range of dictionary characters - * @param foundBreaks Output of C array of int32_t break positions, or 0 - * @return The number of breaks found - */ - virtual int32_t divideUpDictionaryRange( UText *text, - int32_t rangeStart, - int32_t rangeEnd, - UVector32 &foundBreaks ) const; - -}; - -/******************************************************************* - * KhmerBreakEngine - */ - -/** - *

KhmerBreakEngine is a kind of DictionaryBreakEngine that uses a - * DictionaryMatcher and heuristics to determine Khmer-specific breaks.

- * - *

After it is constructed a KhmerBreakEngine may be shared between - * threads without synchronization.

- */ -class KhmerBreakEngine : public DictionaryBreakEngine { - private: - /** - * The set of characters handled by this engine - * @internal - */ - - UnicodeSet fKhmerWordSet; - UnicodeSet fEndWordSet; - UnicodeSet fBeginWordSet; - UnicodeSet fMarkSet; - DictionaryMatcher *fDictionary; - - public: - - /** - *

Default constructor.

- * - * @param adoptDictionary A DictionaryMatcher to adopt. Deleted when the - * engine is deleted. - */ - KhmerBreakEngine(DictionaryMatcher *adoptDictionary, UErrorCode &status); - - /** - *

Virtual destructor.

- */ - virtual ~KhmerBreakEngine(); - - protected: - /** - *

Divide up a range of known dictionary characters.

- * - * @param text A UText representing the text - * @param rangeStart The start of the range of dictionary characters - * @param rangeEnd The end of the range of dictionary characters - * @param foundBreaks Output of C array of int32_t break positions, or 0 - * @return The number of breaks found - */ - virtual int32_t divideUpDictionaryRange( UText *text, - int32_t rangeStart, - int32_t rangeEnd, - UVector32 &foundBreaks ) const; - -}; - #if !UCONFIG_NO_NORMALIZATION /******************************************************************* @@ -385,12 +397,14 @@ class CjkBreakEngine : public DictionaryBreakEngine { * @param rangeStart The start of the range of dictionary characters * @param rangeEnd The end of the range of dictionary characters * @param foundBreaks Output of C array of int32_t break positions, or 0 + * @param status Information on any errors encountered. * @return The number of breaks found */ virtual int32_t divideUpDictionaryRange( UText *text, int32_t rangeStart, int32_t rangeEnd, - UVector32 &foundBreaks ) const; + UVector32 &foundBreaks, + UErrorCode& status) const; }; diff --git a/icu4c/source/common/lstmbe.cpp b/icu4c/source/common/lstmbe.cpp index be7eee624e7..a9123e9d494 100644 --- a/icu4c/source/common/lstmbe.cpp +++ b/icu4c/source/common/lstmbe.cpp @@ -162,10 +162,16 @@ ConstArray2D::~ConstArray2D() class Array1D : public ReadArray1D { public: Array1D() : memory_(nullptr), data_(nullptr), d1_(0) {} - Array1D(int32_t d1) + Array1D(int32_t d1, UErrorCode &status) : memory_(uprv_malloc(d1 * sizeof(float))), data_((float*)memory_), d1_(d1) { - clear(); + if (U_SUCCESS(status)) { + if (memory_ == nullptr) { + status = U_MEMORY_ALLOCATION_ERROR; + return; + } + clear(); + } } virtual ~Array1D(); @@ -278,10 +284,16 @@ Array1D::~Array1D() class Array2D : public ReadArray2D { public: Array2D() : memory_(nullptr), data_(nullptr), d1_(0), d2_(0) {} - Array2D(int32_t d1, int32_t d2) + Array2D(int32_t d1, int32_t d2, UErrorCode &status) : 
memory_(uprv_malloc(d1 * d2 * sizeof(float))), data_((float*)memory_), d1_(d1), d2_(d2) { - clear(); + if (U_SUCCESS(status)) { + if (memory_ == nullptr) { + status = U_MEMORY_ALLOCATION_ERROR; + return; + } + clear(); + } } virtual ~Array2D(); @@ -366,8 +378,10 @@ LSTMData::LSTMData(UResourceBundle* rb, UErrorCode &status) int32_t embedding_size = ures_getInt(embeddings_res.getAlias(), &status); LocalUResourceBundlePointer hunits_res( ures_getByKey(rb, "hunits", nullptr, &status)); + if (U_FAILURE(status)) return; int32_t hunits = ures_getInt(hunits_res.getAlias(), &status); const UChar* type = ures_getStringByKey(rb, "type", nullptr, &status); + if (U_FAILURE(status)) return; if (u_strCompare(type, -1, u"codepoints", -1, false) == 0) { fType = CODE_POINTS; } else if (u_strCompare(type, -1, u"graphclust", -1, false) == 0) { @@ -375,15 +389,15 @@ LSTMData::LSTMData(UResourceBundle* rb, UErrorCode &status) } fName = ures_getStringByKey(rb, "model", nullptr, &status); fDataRes = ures_getByKey(rb, "data", nullptr, &status); + if (U_FAILURE(status)) return; int32_t data_len = 0; const int32_t* data = ures_getIntVector(fDataRes, &data_len, &status); + if (U_FAILURE(status)) return; LocalUResourceBundlePointer fDictRes( ures_getByKey(rb, "dict", nullptr, &status)); int32_t num_index = ures_getSize(fDictRes.getAlias()); fDict = uhash_open(uhash_hashUChars, uhash_compareUChars, nullptr, &status); - if (U_FAILURE(status)) { - return; - } + if (U_FAILURE(status)) return; ures_resetIterator(fDictRes.getAlias()); int32_t idx = 0; @@ -391,10 +405,9 @@ LSTMData::LSTMData(UResourceBundle* rb, UErrorCode &status) while(ures_hasNext(fDictRes.getAlias())) { const char *tempKey = nullptr; const UChar* str = ures_getNextString(fDictRes.getAlias(), nullptr, &tempKey, &status); + if (U_FAILURE(status)) return; uhash_putiAllowZero(fDict, (void*)str, idx++, &status); - if (U_FAILURE(status)) { - return; - } + if (U_FAILURE(status)) return; #ifdef LSTM_VECTORIZER_DEBUG printf("Assign ["); 
while (*str != 0x0000) { @@ -495,6 +508,7 @@ void CodePointsVectorizer::vectorize( { if (offsets.ensureCapacity(endPos - startPos, status) && indices.ensureCapacity(endPos - startPos, status)) { + if (U_FAILURE(status)) return; utext_setNativeIndex(text, startPos); int32_t current; UChar str[2] = {0, 0}; @@ -533,21 +547,16 @@ void GraphemeClusterVectorizer::vectorize( UText *text, int32_t startPos, int32_t endPos, UVector32 &offsets, UVector32 &indices, UErrorCode &status) const { - if (U_FAILURE(status)) { - return; - } + if (U_FAILURE(status)) return; if (!offsets.ensureCapacity(endPos - startPos, status) || !indices.ensureCapacity(endPos - startPos, status)) { return; } + if (U_FAILURE(status)) return; LocalPointer graphemeIter(BreakIterator::createCharacterInstance(Locale(), status)); - if (U_FAILURE(status)) { - return; - } + if (U_FAILURE(status)) return; graphemeIter->setText(text, status); - if (U_FAILURE(status)) { - return; - } + if (U_FAILURE(status)) return; if (startPos != 0) { graphemeIter->preceding(startPos); @@ -561,11 +570,10 @@ void GraphemeClusterVectorizer::vectorize( } if (current > startPos) { utext_extract(text, last, current, str, MAX_GRAPHEME_CLSTER_LENTH, &status); - if (U_FAILURE(status)) { - break; - } + if (U_FAILURE(status)) return; offsets.addElement(last, status); indices.addElement(stringToIndex(str), status); + if (U_FAILURE(status)) return; } last = current; } @@ -583,13 +591,19 @@ void GraphemeClusterVectorizer::vectorize( // https://en.wikipedia.org/wiki/Long_short-term_memory#LSTM_with_a_forget_gate void compute( const ReadArray2D& W, const ReadArray2D& U, const ReadArray1D& b, - const ReadArray1D& x, Array1D& h, Array1D& c) + const ReadArray1D& x, Array1D& h, Array1D& c, + UErrorCode &status) { + if (U_FAILURE(status)) return; // ifco = x * W + h * U + b - Array1D ifco(b.d1()); - ifco.dotProduct(x, W) - .add(Array1D(b.d1()).dotProduct(h, U)) - .add(b); + Array1D ifco(b.d1(), status); + { + Array1D hU(b.d1(), status); + if 
(U_FAILURE(status)) return; + ifco.dotProduct(x, W) + .add(hU.dotProduct(h, U)) + .add(b); + } // delocate hU int32_t hunits = b.d1() / 4; ifco.slice(0*hunits, hunits).sigmoid(); // i: sigmod @@ -597,10 +611,14 @@ void compute( ifco.slice(2*hunits, hunits).tanh(); // c_: tanh ifco.slice(3*hunits, hunits).sigmoid(); // o: sigmod - c.hadamardProduct(ifco.slice(hunits, hunits)) - .add(Array1D(c.d1()) - .assign(ifco.slice(0, hunits)) - .hadamardProduct(ifco.slice(2*hunits, hunits))); + { + Array1D ic(c.d1(), status); + if (U_FAILURE(status)) return; + c.hadamardProduct(ifco.slice(hunits, hunits)) + .add(ic + .assign(ifco.slice(0, hunits)) + .hadamardProduct(ifco.slice(2*hunits, hunits))); + } h.assign(c) .tanh() @@ -614,10 +632,12 @@ static const int32_t MIN_WORD = 2; static const int32_t MIN_WORD_SPAN = MIN_WORD * 2; int32_t -LSTMBreakEngine::divideUpDictionaryRange(UText *text, - int32_t startPos, - int32_t endPos, - UVector32 &foundBreaks) const { +LSTMBreakEngine::divideUpDictionaryRange( UText *text, + int32_t startPos, + int32_t endPos, + UVector32 &foundBreaks, + UErrorCode& status) const { + if (U_FAILURE(status)) return 0; int32_t beginFoundBreakSize = foundBreaks.size(); utext_setNativeIndex(text, startPos); utext_moveIndex32(text, MIN_WORD_SPAN); @@ -625,11 +645,12 @@ LSTMBreakEngine::divideUpDictionaryRange(UText *text, return 0; // Not enough characters for two words } utext_setNativeIndex(text, startPos); - UErrorCode status = U_ZERO_ERROR; UVector32 offsets(status); UVector32 indices(status); + if (U_FAILURE(status)) return 0; fVectorizer->vectorize(text, startPos, endPos, offsets, indices, status); + if (U_FAILURE(status)) return 0; int32_t* offsetsBuf = offsets.getBuffer(); int32_t* indicesBuf = indices.getBuffer(); @@ -640,12 +661,13 @@ LSTMBreakEngine::divideUpDictionaryRange(UText *text, // Python or ICU4X implementation. 
We first perform the Backward LSTM // and then merge the iteration of the forward LSTM and the output layer // together because we only neetdto remember the h[t-1] for Forward LSTM. - Array1D c(hunits); + Array1D c(hunits, status); // TODO: limit size of hBackward. If input_seq_len is too big, we could // run out of memory. // Backward LSTM - Array2D hBackward(input_seq_len, hunits); + Array2D hBackward(input_seq_len, hunits, status); + if (U_FAILURE(status)) return 0; for (int32_t i = input_seq_len - 1; i >= 0; i--) { Array1D hRow = hBackward.row(i); if (i != input_seq_len - 1) { @@ -660,13 +682,15 @@ LSTMBreakEngine::divideUpDictionaryRange(UText *text, #endif // LSTM_DEBUG compute(fData->fBackwardW, fData->fBackwardU, fData->fBackwardB, fData->fEmbedding.row(indicesBuf[i]), - hRow, c); + hRow, c, status); + if (U_FAILURE(status)) return 0; } - Array1D logp(4); + Array1D logp(4, status); // Allocate fbRow and slice the internal array in two. - Array1D fbRow(2 * hunits); + Array1D fbRow(2 * hunits, status); + if (U_FAILURE(status)) return 0; Array1D forwardRow = fbRow.slice(0, hunits); // point to first half of data in fbRow. Array1D backwardRow = fbRow.slice(hunits, hunits); // point to second half of data n fbRow. @@ -683,7 +707,8 @@ LSTMBreakEngine::divideUpDictionaryRange(UText *text, // of fbRow. compute(fData->fForwardW, fData->fForwardU, fData->fForwardB, fData->fEmbedding.row(indicesBuf[i]), - forwardRow, c); + forwardRow, c, status); + if (U_FAILURE(status)) return 0; // assign the data from hBackward.row(i) to second half of fbRowa. 
backwardRow.assign(hBackward.row(i)); @@ -702,6 +727,7 @@ LSTMBreakEngine::divideUpDictionaryRange(UText *text, if (current == BEGIN || current == SINGLE) { if (i != 0) { foundBreaks.addElement(offsetsBuf[i], status); + if (U_FAILURE(status)) return 0; } } } @@ -759,14 +785,13 @@ U_CAPI const LSTMData* U_EXPORT2 CreateLSTMDataForScript(UScriptCode script, UEr return nullptr; } UnicodeString name = defaultLSTM(script, status); + if (U_FAILURE(status)) return nullptr; CharString namebuf; namebuf.appendInvariantChars(name, status).truncate(namebuf.lastIndexOf('.')); LocalUResourceBundlePointer rb( ures_openDirect(U_ICUDATA_BRKITR, namebuf.data(), &status)); - if (U_FAILURE(status)) { - return nullptr; - } + if (U_FAILURE(status)) return nullptr; return CreateLSTMData(rb.getAlias(), status); } diff --git a/icu4c/source/common/lstmbe.h b/icu4c/source/common/lstmbe.h index 38c75d6db98..84015e3a1fc 100644 --- a/icu4c/source/common/lstmbe.h +++ b/icu4c/source/common/lstmbe.h @@ -55,12 +55,14 @@ protected: * @param rangeStart The start of the range of dictionary characters * @param rangeEnd The end of the range of dictionary characters * @param foundBreaks Output of C array of int32_t break positions, or 0 + * @param status Information on any errors encountered. * @return The number of breaks found */ virtual int32_t divideUpDictionaryRange(UText *text, int32_t rangeStart, int32_t rangeEnd, - UVector32 &foundBreaks ) const; + UVector32 &foundBreaks, + UErrorCode& status) const; private: const LSTMData* fData; const Vectorizer* fVectorizer; diff --git a/icu4c/source/common/rbbi_cache.cpp b/icu4c/source/common/rbbi_cache.cpp index 44f19d86973..f3a89fdccd5 100644 --- a/icu4c/source/common/rbbi_cache.cpp +++ b/icu4c/source/common/rbbi_cache.cpp @@ -163,7 +163,7 @@ void RuleBasedBreakIterator::DictionaryCache::populateDictionary(int32_t startPo // Ask the language object if there are any breaks. 
It will add them to the cache and // leave the text pointer on the other side of its range, ready to search for the next one. if (lbe != NULL) { - foundBreakCount += lbe->findBreaks(text, rangeStart, rangeEnd, fBreaks); + foundBreakCount += lbe->findBreaks(text, rangeStart, rangeEnd, fBreaks, status); } // Reload the loop variables for the next go-round diff --git a/icu4c/source/test/intltest/lstmbetst.cpp b/icu4c/source/test/intltest/lstmbetst.cpp index f53421f9345..7281a37e9fb 100644 --- a/icu4c/source/test/intltest/lstmbetst.cpp +++ b/icu4c/source/test/intltest/lstmbetst.cpp @@ -151,7 +151,11 @@ void LSTMBETest::runTestFromFile(const char* filename) { dataerrln("%s:%d Error %s Could not allocate UVextor32", __FILE__, __LINE__, u_errorName(status)); return; } - engine->findBreaks(&ut, 0, value.length(), actual); + engine->findBreaks(&ut, 0, value.length(), actual, status); + if (U_FAILURE(status)) { + dataerrln("%s:%d Error %s findBreaks failed", __FILE__, __LINE__, u_errorName(status)); + return; + } utext_close(&ut); for (int32_t i = 0; i < actual.size(); i++) { ss << actual.elementAti(i) << ", ";