ICU-21569 Propagate LSTM memory allocation issue

This commit is contained in:
Frank Tang 2021-04-30 19:19:24 -07:00 committed by Frank Yung-Fong Tang
parent 512290fd23
commit 4136fa207f
8 changed files with 243 additions and 187 deletions

View file

@ -78,7 +78,9 @@ int32_t
UnhandledEngine::findBreaks( UText *text,
int32_t /* startPos */,
int32_t endPos,
UVector32 &/*foundBreaks*/ ) const {
UVector32 &/*foundBreaks*/,
UErrorCode &status) const {
if (U_FAILURE(status)) return 0;
UChar32 c = utext_current32(text);
while((int32_t)utext_getNativeIndex(text) < endPos && fHandled->contains(c)) {
utext_next32(text); // TODO: recast loop to work with post-increment operations.

View file

@ -68,12 +68,14 @@ class LanguageBreakEngine : public UMemory {
* @param startPos The start of the run within the supplied text.
* @param endPos The end of the run within the supplied text.
* @param foundBreaks A Vector of int32_t to receive the breaks.
* @param status Information on any errors encountered.
* @return The number of breaks found.
*/
virtual int32_t findBreaks( UText *text,
int32_t startPos,
int32_t endPos,
UVector32 &foundBreaks ) const = 0;
UVector32 &foundBreaks,
UErrorCode &status) const = 0;
};
@ -185,12 +187,14 @@ class UnhandledEngine : public LanguageBreakEngine {
* @param startPos The start of the run within the supplied text.
* @param endPos The end of the run within the supplied text.
* @param foundBreaks A Vector of int32_t to receive the breaks.
* @param status Information on any errors encountered.
* @return The number of breaks found.
*/
virtual int32_t findBreaks( UText *text,
int32_t startPos,
int32_t endPos,
UVector32 &foundBreaks ) const;
UVector32 &foundBreaks,
UErrorCode &status) const;
/**
* <p>Tell the engine to handle a particular character and break type.</p>

View file

@ -47,7 +47,9 @@ int32_t
DictionaryBreakEngine::findBreaks( UText *text,
int32_t startPos,
int32_t endPos,
UVector32 &foundBreaks ) const {
UVector32 &foundBreaks,
UErrorCode& status) const {
if (U_FAILURE(status)) return 0;
(void)startPos; // TODO: remove this param?
int32_t result = 0;
@ -66,7 +68,7 @@ DictionaryBreakEngine::findBreaks( UText *text,
}
rangeStart = start;
rangeEnd = current;
result = divideUpDictionaryRange(text, rangeStart, rangeEnd, foundBreaks);
result = divideUpDictionaryRange(text, rangeStart, rangeEnd, foundBreaks, status);
utext_setNativeIndex(text, current);
return result;
@ -227,7 +229,9 @@ int32_t
ThaiBreakEngine::divideUpDictionaryRange( UText *text,
int32_t rangeStart,
int32_t rangeEnd,
UVector32 &foundBreaks ) const {
UVector32 &foundBreaks,
UErrorCode& status) const {
if (U_FAILURE(status)) return 0;
utext_setNativeIndex(text, rangeStart);
utext_moveIndex32(text, THAI_MIN_WORD_SPAN);
if (utext_getNativeIndex(text) >= rangeEnd) {
@ -240,7 +244,6 @@ ThaiBreakEngine::divideUpDictionaryRange( UText *text,
int32_t cpWordLength = 0; // Word Length in Code Points.
int32_t cuWordLength = 0; // Word length in code units (UText native indexing)
int32_t current;
UErrorCode status = U_ZERO_ERROR;
PossibleWord words[THAI_LOOKAHEAD];
utext_setNativeIndex(text, rangeStart);
@ -465,7 +468,9 @@ int32_t
LaoBreakEngine::divideUpDictionaryRange( UText *text,
int32_t rangeStart,
int32_t rangeEnd,
UVector32 &foundBreaks ) const {
UVector32 &foundBreaks,
UErrorCode& status) const {
if (U_FAILURE(status)) return 0;
if ((rangeEnd - rangeStart) < LAO_MIN_WORD_SPAN) {
return 0; // Not enough characters for two words
}
@ -474,11 +479,10 @@ LaoBreakEngine::divideUpDictionaryRange( UText *text,
int32_t cpWordLength = 0;
int32_t cuWordLength = 0;
int32_t current;
UErrorCode status = U_ZERO_ERROR;
PossibleWord words[LAO_LOOKAHEAD];
utext_setNativeIndex(text, rangeStart);
while (U_SUCCESS(status) && (current = (int32_t)utext_getNativeIndex(text)) < rangeEnd) {
cuWordLength = 0;
cpWordLength = 0;
@ -657,7 +661,9 @@ int32_t
BurmeseBreakEngine::divideUpDictionaryRange( UText *text,
int32_t rangeStart,
int32_t rangeEnd,
UVector32 &foundBreaks ) const {
UVector32 &foundBreaks,
UErrorCode& status ) const {
if (U_FAILURE(status)) return 0;
if ((rangeEnd - rangeStart) < BURMESE_MIN_WORD_SPAN) {
return 0; // Not enough characters for two words
}
@ -666,11 +672,10 @@ BurmeseBreakEngine::divideUpDictionaryRange( UText *text,
int32_t cpWordLength = 0;
int32_t cuWordLength = 0;
int32_t current;
UErrorCode status = U_ZERO_ERROR;
PossibleWord words[BURMESE_LOOKAHEAD];
utext_setNativeIndex(text, rangeStart);
while (U_SUCCESS(status) && (current = (int32_t)utext_getNativeIndex(text)) < rangeEnd) {
cuWordLength = 0;
cpWordLength = 0;
@ -861,7 +866,9 @@ int32_t
KhmerBreakEngine::divideUpDictionaryRange( UText *text,
int32_t rangeStart,
int32_t rangeEnd,
UVector32 &foundBreaks ) const {
UVector32 &foundBreaks,
UErrorCode& status ) const {
if (U_FAILURE(status)) return 0;
if ((rangeEnd - rangeStart) < KHMER_MIN_WORD_SPAN) {
return 0; // Not enough characters for two words
}
@ -870,7 +877,6 @@ KhmerBreakEngine::divideUpDictionaryRange( UText *text,
int32_t cpWordLength = 0;
int32_t cuWordLength = 0;
int32_t current;
UErrorCode status = U_ZERO_ERROR;
PossibleWord words[KHMER_LOOKAHEAD];
utext_setNativeIndex(text, rangeStart);
@ -1110,7 +1116,9 @@ int32_t
CjkBreakEngine::divideUpDictionaryRange( UText *inText,
int32_t rangeStart,
int32_t rangeEnd,
UVector32 &foundBreaks ) const {
UVector32 &foundBreaks,
UErrorCode& status) const {
if (U_FAILURE(status)) return 0;
if (rangeStart >= rangeEnd) {
return 0;
}
@ -1122,9 +1130,6 @@ CjkBreakEngine::divideUpDictionaryRange( UText *inText,
// If NULL then mapping is 1:1
LocalPointer<UVector32> inputMap;
UErrorCode status = U_ZERO_ERROR;
// if UText has the input string as one contiguous UTF-16 chunk
if ((inText->providerProperties & utext_i32_flag(UTEXT_PROVIDER_STABLE_CHUNKS)) &&
inText->chunkNativeStart <= rangeStart &&

View file

@ -68,17 +68,19 @@ class DictionaryBreakEngine : public LanguageBreakEngine {
* <p>Find any breaks within a run in the supplied text.</p>
*
* @param text A UText representing the text. The iterator is left at
* the end of the run of characters which the engine is capable of handling
* the end of the run of characters which the engine is capable of handling
* that starts from the first character in the range.
* @param startPos The start of the run within the supplied text.
* @param endPos The end of the run within the supplied text.
* @param foundBreaks vector of int32_t to receive the break positions
* @param status Information on any errors encountered.
* @return The number of breaks found.
*/
virtual int32_t findBreaks( UText *text,
int32_t startPos,
int32_t endPos,
UVector32 &foundBreaks ) const;
UVector32 &foundBreaks,
UErrorCode& status ) const;
protected:
@ -96,12 +98,14 @@ class DictionaryBreakEngine : public LanguageBreakEngine {
* @param rangeStart The start of the range of dictionary characters
* @param rangeEnd The end of the range of dictionary characters
* @param foundBreaks A UVector32 to receive the int32_t break positions
* @param status Information on any errors encountered.
* @return The number of breaks found
*/
virtual int32_t divideUpDictionaryRange( UText *text,
int32_t rangeStart,
int32_t rangeEnd,
UVector32 &foundBreaks ) const = 0;
UVector32 &foundBreaks,
UErrorCode& status) const = 0;
};
@ -153,12 +157,14 @@ class ThaiBreakEngine : public DictionaryBreakEngine {
* @param rangeStart The start of the range of dictionary characters
* @param rangeEnd The end of the range of dictionary characters
* @param foundBreaks A UVector32 to receive the int32_t break positions
* @param status Information on any errors encountered.
* @return The number of breaks found
*/
virtual int32_t divideUpDictionaryRange( UText *text,
int32_t rangeStart,
int32_t rangeEnd,
UVector32 &foundBreaks ) const;
UVector32 &foundBreaks,
UErrorCode& status) const;
};
@ -209,127 +215,133 @@ class LaoBreakEngine : public DictionaryBreakEngine {
* @param rangeStart The start of the range of dictionary characters
* @param rangeEnd The end of the range of dictionary characters
* @param foundBreaks A UVector32 to receive the int32_t break positions
* @param status Information on any errors encountered.
* @return The number of breaks found
*/
virtual int32_t divideUpDictionaryRange( UText *text,
int32_t rangeStart,
int32_t rangeEnd,
UVector32 &foundBreaks ) const;
UVector32 &foundBreaks,
UErrorCode& status) const;
};
/*******************************************************************
* BurmeseBreakEngine
*/
/**
* <p>BurmeseBreakEngine is a kind of DictionaryBreakEngine that uses a
* DictionaryMatcher and heuristics to determine Burmese-specific breaks.</p>
*
* <p>After it is constructed a BurmeseBreakEngine may be shared between
* threads without synchronization.</p>
*/
class BurmeseBreakEngine : public DictionaryBreakEngine {
private:
/**
* The set of characters handled by this engine, followed by further
* character sets whose roles are suggested only by their names; their
* initialization is not visible in this declaration -- TODO confirm
* against the constructor.
* @internal
*/
UnicodeSet fBurmeseWordSet;
UnicodeSet fEndWordSet;
UnicodeSet fBeginWordSet;
UnicodeSet fMarkSet;
/**
* Dictionary matcher held by this engine; presumably the adoptDictionary
* passed to the constructor -- TODO confirm against the constructor.
* @internal
*/
DictionaryMatcher *fDictionary;
public:
/**
* <p>Constructor.</p>
*
* @param adoptDictionary A DictionaryMatcher to adopt. Deleted when the
* engine is deleted.
* @param status ICU error code (in/out). How it is used during
* construction is not visible in this declaration.
*/
BurmeseBreakEngine(DictionaryMatcher *adoptDictionary, UErrorCode &status);
/**
* <p>Virtual destructor.</p>
*/
virtual ~BurmeseBreakEngine();
protected:
/**
* <p>Divide up a range of known dictionary characters.</p>
*
* @param text A UText representing the text
* @param rangeStart The start of the range of dictionary characters
* @param rangeEnd The end of the range of dictionary characters
* @param foundBreaks A UVector32 to receive the int32_t break positions
* @param status Information on any errors encountered.
* @return The number of breaks found
*/
virtual int32_t divideUpDictionaryRange( UText *text,
int32_t rangeStart,
int32_t rangeEnd,
UVector32 &foundBreaks,
UErrorCode& status) const;
};
/*******************************************************************
* KhmerBreakEngine
*/
/**
* <p>KhmerBreakEngine is a kind of DictionaryBreakEngine that uses a
* DictionaryMatcher and heuristics to determine Khmer-specific breaks.</p>
*
* <p>After it is constructed a KhmerBreakEngine may be shared between
* threads without synchronization.</p>
*/
class KhmerBreakEngine : public DictionaryBreakEngine {
private:
/**
* The set of characters handled by this engine, followed by further
* character sets whose roles are suggested only by their names; their
* initialization is not visible in this declaration -- TODO confirm
* against the constructor.
* @internal
*/
UnicodeSet fKhmerWordSet;
UnicodeSet fEndWordSet;
UnicodeSet fBeginWordSet;
UnicodeSet fMarkSet;
/**
* Dictionary matcher held by this engine; presumably the adoptDictionary
* passed to the constructor -- TODO confirm against the constructor.
* @internal
*/
DictionaryMatcher *fDictionary;
public:
/**
* <p>Constructor.</p>
*
* @param adoptDictionary A DictionaryMatcher to adopt. Deleted when the
* engine is deleted.
* @param status ICU error code (in/out). How it is used during
* construction is not visible in this declaration.
*/
KhmerBreakEngine(DictionaryMatcher *adoptDictionary, UErrorCode &status);
/**
* <p>Virtual destructor.</p>
*/
virtual ~KhmerBreakEngine();
protected:
/**
* <p>Divide up a range of known dictionary characters.</p>
*
* @param text A UText representing the text
* @param rangeStart The start of the range of dictionary characters
* @param rangeEnd The end of the range of dictionary characters
* @param foundBreaks A UVector32 to receive the int32_t break positions
* @param status Information on any errors encountered.
* @return The number of breaks found
*/
virtual int32_t divideUpDictionaryRange( UText *text,
int32_t rangeStart,
int32_t rangeEnd,
UVector32 &foundBreaks,
UErrorCode& status) const;
};
/*******************************************************************
* BurmeseBreakEngine
*/
/**
* <p>BurmeseBreakEngine is a kind of DictionaryBreakEngine that uses a
* DictionaryMatcher and heuristics to determine Burmese-specific breaks.</p>
*
* <p>After it is constructed a BurmeseBreakEngine may be shared between
* threads without synchronization.</p>
*/
class BurmeseBreakEngine : public DictionaryBreakEngine {
private:
/**
* The set of characters handled by this engine, followed by further
* character sets whose roles are suggested only by their names; their
* initialization is not visible in this declaration -- TODO confirm
* against the constructor.
* @internal
*/
UnicodeSet fBurmeseWordSet;
UnicodeSet fEndWordSet;
UnicodeSet fBeginWordSet;
UnicodeSet fMarkSet;
/**
* Dictionary matcher held by this engine; presumably the adoptDictionary
* passed to the constructor -- TODO confirm against the constructor.
* @internal
*/
DictionaryMatcher *fDictionary;
public:
/**
* <p>Constructor.</p>
*
* @param adoptDictionary A DictionaryMatcher to adopt. Deleted when the
* engine is deleted.
* @param status ICU error code (in/out). How it is used during
* construction is not visible in this declaration.
*/
BurmeseBreakEngine(DictionaryMatcher *adoptDictionary, UErrorCode &status);
/**
* <p>Virtual destructor.</p>
*/
virtual ~BurmeseBreakEngine();
protected:
/**
* <p>Divide up a range of known dictionary characters.</p>
*
* @param text A UText representing the text
* @param rangeStart The start of the range of dictionary characters
* @param rangeEnd The end of the range of dictionary characters
* @param foundBreaks A UVector32 to receive the int32_t break positions
* @return The number of breaks found
*/
virtual int32_t divideUpDictionaryRange( UText *text,
int32_t rangeStart,
int32_t rangeEnd,
UVector32 &foundBreaks ) const;
};
/*******************************************************************
* KhmerBreakEngine
*/
/**
* <p>KhmerBreakEngine is a kind of DictionaryBreakEngine that uses a
* DictionaryMatcher and heuristics to determine Khmer-specific breaks.</p>
*
* <p>After it is constructed a KhmerBreakEngine may be shared between
* threads without synchronization.</p>
*/
class KhmerBreakEngine : public DictionaryBreakEngine {
private:
/**
* The set of characters handled by this engine, followed by further
* character sets whose roles are suggested only by their names; their
* initialization is not visible in this declaration -- TODO confirm
* against the constructor.
* @internal
*/
UnicodeSet fKhmerWordSet;
UnicodeSet fEndWordSet;
UnicodeSet fBeginWordSet;
UnicodeSet fMarkSet;
/**
* Dictionary matcher held by this engine; presumably the adoptDictionary
* passed to the constructor -- TODO confirm against the constructor.
* @internal
*/
DictionaryMatcher *fDictionary;
public:
/**
* <p>Constructor.</p>
*
* @param adoptDictionary A DictionaryMatcher to adopt. Deleted when the
* engine is deleted.
* @param status ICU error code (in/out). How it is used during
* construction is not visible in this declaration.
*/
KhmerBreakEngine(DictionaryMatcher *adoptDictionary, UErrorCode &status);
/**
* <p>Virtual destructor.</p>
*/
virtual ~KhmerBreakEngine();
protected:
/**
* <p>Divide up a range of known dictionary characters.</p>
*
* @param text A UText representing the text
* @param rangeStart The start of the range of dictionary characters
* @param rangeEnd The end of the range of dictionary characters
* @param foundBreaks A UVector32 to receive the int32_t break positions
* @return The number of breaks found
*/
virtual int32_t divideUpDictionaryRange( UText *text,
int32_t rangeStart,
int32_t rangeEnd,
UVector32 &foundBreaks ) const;
};
#if !UCONFIG_NO_NORMALIZATION
/*******************************************************************
@ -385,12 +397,14 @@ class CjkBreakEngine : public DictionaryBreakEngine {
* @param rangeStart The start of the range of dictionary characters
* @param rangeEnd The end of the range of dictionary characters
* @param foundBreaks A UVector32 to receive the int32_t break positions
* @param status Information on any errors encountered.
* @return The number of breaks found
*/
virtual int32_t divideUpDictionaryRange( UText *text,
int32_t rangeStart,
int32_t rangeEnd,
UVector32 &foundBreaks ) const;
UVector32 &foundBreaks,
UErrorCode& status) const;
};

View file

@ -162,10 +162,16 @@ ConstArray2D::~ConstArray2D()
class Array1D : public ReadArray1D {
public:
Array1D() : memory_(nullptr), data_(nullptr), d1_(0) {}
Array1D(int32_t d1)
Array1D(int32_t d1, UErrorCode &status)
: memory_(uprv_malloc(d1 * sizeof(float))),
data_((float*)memory_), d1_(d1) {
clear();
if (U_SUCCESS(status)) {
if (memory_ == nullptr) {
status = U_MEMORY_ALLOCATION_ERROR;
return;
}
clear();
}
}
virtual ~Array1D();
@ -278,10 +284,16 @@ Array1D::~Array1D()
class Array2D : public ReadArray2D {
public:
Array2D() : memory_(nullptr), data_(nullptr), d1_(0), d2_(0) {}
Array2D(int32_t d1, int32_t d2)
Array2D(int32_t d1, int32_t d2, UErrorCode &status)
: memory_(uprv_malloc(d1 * d2 * sizeof(float))),
data_((float*)memory_), d1_(d1), d2_(d2) {
clear();
if (U_SUCCESS(status)) {
if (memory_ == nullptr) {
status = U_MEMORY_ALLOCATION_ERROR;
return;
}
clear();
}
}
virtual ~Array2D();
@ -366,8 +378,10 @@ LSTMData::LSTMData(UResourceBundle* rb, UErrorCode &status)
int32_t embedding_size = ures_getInt(embeddings_res.getAlias(), &status);
LocalUResourceBundlePointer hunits_res(
ures_getByKey(rb, "hunits", nullptr, &status));
if (U_FAILURE(status)) return;
int32_t hunits = ures_getInt(hunits_res.getAlias(), &status);
const UChar* type = ures_getStringByKey(rb, "type", nullptr, &status);
if (U_FAILURE(status)) return;
if (u_strCompare(type, -1, u"codepoints", -1, false) == 0) {
fType = CODE_POINTS;
} else if (u_strCompare(type, -1, u"graphclust", -1, false) == 0) {
@ -375,15 +389,15 @@ LSTMData::LSTMData(UResourceBundle* rb, UErrorCode &status)
}
fName = ures_getStringByKey(rb, "model", nullptr, &status);
fDataRes = ures_getByKey(rb, "data", nullptr, &status);
if (U_FAILURE(status)) return;
int32_t data_len = 0;
const int32_t* data = ures_getIntVector(fDataRes, &data_len, &status);
if (U_FAILURE(status)) return;
LocalUResourceBundlePointer fDictRes(
ures_getByKey(rb, "dict", nullptr, &status));
int32_t num_index = ures_getSize(fDictRes.getAlias());
fDict = uhash_open(uhash_hashUChars, uhash_compareUChars, nullptr, &status);
if (U_FAILURE(status)) {
return;
}
if (U_FAILURE(status)) return;
ures_resetIterator(fDictRes.getAlias());
int32_t idx = 0;
@ -391,10 +405,9 @@ LSTMData::LSTMData(UResourceBundle* rb, UErrorCode &status)
while(ures_hasNext(fDictRes.getAlias())) {
const char *tempKey = nullptr;
const UChar* str = ures_getNextString(fDictRes.getAlias(), nullptr, &tempKey, &status);
if (U_FAILURE(status)) return;
uhash_putiAllowZero(fDict, (void*)str, idx++, &status);
if (U_FAILURE(status)) {
return;
}
if (U_FAILURE(status)) return;
#ifdef LSTM_VECTORIZER_DEBUG
printf("Assign [");
while (*str != 0x0000) {
@ -495,6 +508,7 @@ void CodePointsVectorizer::vectorize(
{
if (offsets.ensureCapacity(endPos - startPos, status) &&
indices.ensureCapacity(endPos - startPos, status)) {
if (U_FAILURE(status)) return;
utext_setNativeIndex(text, startPos);
int32_t current;
UChar str[2] = {0, 0};
@ -533,21 +547,16 @@ void GraphemeClusterVectorizer::vectorize(
UText *text, int32_t startPos, int32_t endPos,
UVector32 &offsets, UVector32 &indices, UErrorCode &status) const
{
if (U_FAILURE(status)) {
return;
}
if (U_FAILURE(status)) return;
if (!offsets.ensureCapacity(endPos - startPos, status) ||
!indices.ensureCapacity(endPos - startPos, status)) {
return;
}
if (U_FAILURE(status)) return;
LocalPointer<BreakIterator> graphemeIter(BreakIterator::createCharacterInstance(Locale(), status));
if (U_FAILURE(status)) {
return;
}
if (U_FAILURE(status)) return;
graphemeIter->setText(text, status);
if (U_FAILURE(status)) {
return;
}
if (U_FAILURE(status)) return;
if (startPos != 0) {
graphemeIter->preceding(startPos);
@ -561,11 +570,10 @@ void GraphemeClusterVectorizer::vectorize(
}
if (current > startPos) {
utext_extract(text, last, current, str, MAX_GRAPHEME_CLSTER_LENTH, &status);
if (U_FAILURE(status)) {
break;
}
if (U_FAILURE(status)) return;
offsets.addElement(last, status);
indices.addElement(stringToIndex(str), status);
if (U_FAILURE(status)) return;
}
last = current;
}
@ -583,13 +591,19 @@ void GraphemeClusterVectorizer::vectorize(
// https://en.wikipedia.org/wiki/Long_short-term_memory#LSTM_with_a_forget_gate
void compute(
const ReadArray2D& W, const ReadArray2D& U, const ReadArray1D& b,
const ReadArray1D& x, Array1D& h, Array1D& c)
const ReadArray1D& x, Array1D& h, Array1D& c,
UErrorCode &status)
{
if (U_FAILURE(status)) return;
// ifco = x * W + h * U + b
Array1D ifco(b.d1());
ifco.dotProduct(x, W)
.add(Array1D(b.d1()).dotProduct(h, U))
.add(b);
Array1D ifco(b.d1(), status);
{
Array1D hU(b.d1(), status);
if (U_FAILURE(status)) return;
ifco.dotProduct(x, W)
.add(hU.dotProduct(h, U))
.add(b);
} // delocate hU
int32_t hunits = b.d1() / 4;
ifco.slice(0*hunits, hunits).sigmoid(); // i: sigmod
@ -597,10 +611,14 @@ void compute(
ifco.slice(2*hunits, hunits).tanh(); // c_: tanh
ifco.slice(3*hunits, hunits).sigmoid(); // o: sigmod
c.hadamardProduct(ifco.slice(hunits, hunits))
.add(Array1D(c.d1())
.assign(ifco.slice(0, hunits))
.hadamardProduct(ifco.slice(2*hunits, hunits)));
{
Array1D ic(c.d1(), status);
if (U_FAILURE(status)) return;
c.hadamardProduct(ifco.slice(hunits, hunits))
.add(ic
.assign(ifco.slice(0, hunits))
.hadamardProduct(ifco.slice(2*hunits, hunits)));
}
h.assign(c)
.tanh()
@ -614,10 +632,12 @@ static const int32_t MIN_WORD = 2;
static const int32_t MIN_WORD_SPAN = MIN_WORD * 2;
int32_t
LSTMBreakEngine::divideUpDictionaryRange(UText *text,
int32_t startPos,
int32_t endPos,
UVector32 &foundBreaks) const {
LSTMBreakEngine::divideUpDictionaryRange( UText *text,
int32_t startPos,
int32_t endPos,
UVector32 &foundBreaks,
UErrorCode& status) const {
if (U_FAILURE(status)) return 0;
int32_t beginFoundBreakSize = foundBreaks.size();
utext_setNativeIndex(text, startPos);
utext_moveIndex32(text, MIN_WORD_SPAN);
@ -625,11 +645,12 @@ LSTMBreakEngine::divideUpDictionaryRange(UText *text,
return 0; // Not enough characters for two words
}
utext_setNativeIndex(text, startPos);
UErrorCode status = U_ZERO_ERROR;
UVector32 offsets(status);
UVector32 indices(status);
if (U_FAILURE(status)) return 0;
fVectorizer->vectorize(text, startPos, endPos, offsets, indices, status);
if (U_FAILURE(status)) return 0;
int32_t* offsetsBuf = offsets.getBuffer();
int32_t* indicesBuf = indices.getBuffer();
@ -640,12 +661,13 @@ LSTMBreakEngine::divideUpDictionaryRange(UText *text,
// Python or ICU4X implementation. We first perform the Backward LSTM
// and then merge the iteration of the forward LSTM and the output layer
// together because we only need to remember the h[t-1] for Forward LSTM.
Array1D c(hunits);
Array1D c(hunits, status);
// TODO: limit size of hBackward. If input_seq_len is too big, we could
// run out of memory.
// Backward LSTM
Array2D hBackward(input_seq_len, hunits);
Array2D hBackward(input_seq_len, hunits, status);
if (U_FAILURE(status)) return 0;
for (int32_t i = input_seq_len - 1; i >= 0; i--) {
Array1D hRow = hBackward.row(i);
if (i != input_seq_len - 1) {
@ -660,13 +682,15 @@ LSTMBreakEngine::divideUpDictionaryRange(UText *text,
#endif // LSTM_DEBUG
compute(fData->fBackwardW, fData->fBackwardU, fData->fBackwardB,
fData->fEmbedding.row(indicesBuf[i]),
hRow, c);
hRow, c, status);
if (U_FAILURE(status)) return 0;
}
Array1D logp(4);
Array1D logp(4, status);
// Allocate fbRow and slice the internal array in two.
Array1D fbRow(2 * hunits);
Array1D fbRow(2 * hunits, status);
if (U_FAILURE(status)) return 0;
Array1D forwardRow = fbRow.slice(0, hunits); // point to first half of data in fbRow.
Array1D backwardRow = fbRow.slice(hunits, hunits); // point to second half of data n fbRow.
@ -683,7 +707,8 @@ LSTMBreakEngine::divideUpDictionaryRange(UText *text,
// of fbRow.
compute(fData->fForwardW, fData->fForwardU, fData->fForwardB,
fData->fEmbedding.row(indicesBuf[i]),
forwardRow, c);
forwardRow, c, status);
if (U_FAILURE(status)) return 0;
// assign the data from hBackward.row(i) to second half of fbRow.
backwardRow.assign(hBackward.row(i));
@ -702,6 +727,7 @@ LSTMBreakEngine::divideUpDictionaryRange(UText *text,
if (current == BEGIN || current == SINGLE) {
if (i != 0) {
foundBreaks.addElement(offsetsBuf[i], status);
if (U_FAILURE(status)) return 0;
}
}
}
@ -759,14 +785,13 @@ U_CAPI const LSTMData* U_EXPORT2 CreateLSTMDataForScript(UScriptCode script, UEr
return nullptr;
}
UnicodeString name = defaultLSTM(script, status);
if (U_FAILURE(status)) return nullptr;
CharString namebuf;
namebuf.appendInvariantChars(name, status).truncate(namebuf.lastIndexOf('.'));
LocalUResourceBundlePointer rb(
ures_openDirect(U_ICUDATA_BRKITR, namebuf.data(), &status));
if (U_FAILURE(status)) {
return nullptr;
}
if (U_FAILURE(status)) return nullptr;
return CreateLSTMData(rb.getAlias(), status);
}

View file

@ -55,12 +55,14 @@ protected:
* @param rangeStart The start of the range of dictionary characters
* @param rangeEnd The end of the range of dictionary characters
* @param foundBreaks A UVector32 to receive the int32_t break positions
* @param status Information on any errors encountered.
* @return The number of breaks found
*/
virtual int32_t divideUpDictionaryRange(UText *text,
int32_t rangeStart,
int32_t rangeEnd,
UVector32 &foundBreaks ) const;
UVector32 &foundBreaks,
UErrorCode& status) const;
private:
const LSTMData* fData;
const Vectorizer* fVectorizer;

View file

@ -163,7 +163,7 @@ void RuleBasedBreakIterator::DictionaryCache::populateDictionary(int32_t startPo
// Ask the language object if there are any breaks. It will add them to the cache and
// leave the text pointer on the other side of its range, ready to search for the next one.
if (lbe != NULL) {
foundBreakCount += lbe->findBreaks(text, rangeStart, rangeEnd, fBreaks);
foundBreakCount += lbe->findBreaks(text, rangeStart, rangeEnd, fBreaks, status);
}
// Reload the loop variables for the next go-round

View file

@ -151,7 +151,11 @@ void LSTMBETest::runTestFromFile(const char* filename) {
dataerrln("%s:%d Error %s Could not allocate UVextor32", __FILE__, __LINE__, u_errorName(status));
return;
}
engine->findBreaks(&ut, 0, value.length(), actual);
engine->findBreaks(&ut, 0, value.length(), actual, status);
if (U_FAILURE(status)) {
dataerrln("%s:%d Error %s findBreaks failed", __FILE__, __LINE__, u_errorName(status));
return;
}
utext_close(&ut);
for (int32_t i = 0; i < actual.size(); i++) {
ss << actual.elementAti(i) << ", ";