mirror of
https://github.com/unicode-org/icu.git
synced 2025-04-06 14:05:32 +00:00
ICU-21569 Propagate LSTM memory allocation issue
This commit is contained in:
parent
512290fd23
commit
4136fa207f
8 changed files with 243 additions and 187 deletions
|
@ -78,7 +78,9 @@ int32_t
|
|||
UnhandledEngine::findBreaks( UText *text,
|
||||
int32_t /* startPos */,
|
||||
int32_t endPos,
|
||||
UVector32 &/*foundBreaks*/ ) const {
|
||||
UVector32 &/*foundBreaks*/,
|
||||
UErrorCode &status) const {
|
||||
if (U_FAILURE(status)) return 0;
|
||||
UChar32 c = utext_current32(text);
|
||||
while((int32_t)utext_getNativeIndex(text) < endPos && fHandled->contains(c)) {
|
||||
utext_next32(text); // TODO: recast loop to work with post-increment operations.
|
||||
|
|
|
@ -68,12 +68,14 @@ class LanguageBreakEngine : public UMemory {
|
|||
* @param startPos The start of the run within the supplied text.
|
||||
* @param endPos The end of the run within the supplied text.
|
||||
* @param foundBreaks A Vector of int32_t to receive the breaks.
|
||||
* @param status Information on any errors encountered.
|
||||
* @return The number of breaks found.
|
||||
*/
|
||||
virtual int32_t findBreaks( UText *text,
|
||||
int32_t startPos,
|
||||
int32_t endPos,
|
||||
UVector32 &foundBreaks ) const = 0;
|
||||
UVector32 &foundBreaks,
|
||||
UErrorCode &status) const = 0;
|
||||
|
||||
};
|
||||
|
||||
|
@ -185,12 +187,14 @@ class UnhandledEngine : public LanguageBreakEngine {
|
|||
* @param startPos The start of the run within the supplied text.
|
||||
* @param endPos The end of the run within the supplied text.
|
||||
* @param foundBreaks A Vector of int32_t to receive the breaks found, if any.
|
||||
* @param status Information on any errors encountered.
|
||||
* @return The number of breaks found.
|
||||
*/
|
||||
virtual int32_t findBreaks( UText *text,
|
||||
int32_t startPos,
|
||||
int32_t endPos,
|
||||
UVector32 &foundBreaks ) const;
|
||||
UVector32 &foundBreaks,
|
||||
UErrorCode &status) const;
|
||||
|
||||
/**
|
||||
* <p>Tell the engine to handle a particular character and break type.</p>
|
||||
|
|
|
@ -47,7 +47,9 @@ int32_t
|
|||
DictionaryBreakEngine::findBreaks( UText *text,
|
||||
int32_t startPos,
|
||||
int32_t endPos,
|
||||
UVector32 &foundBreaks ) const {
|
||||
UVector32 &foundBreaks,
|
||||
UErrorCode& status) const {
|
||||
if (U_FAILURE(status)) return 0;
|
||||
(void)startPos; // TODO: remove this param?
|
||||
int32_t result = 0;
|
||||
|
||||
|
@ -66,7 +68,7 @@ DictionaryBreakEngine::findBreaks( UText *text,
|
|||
}
|
||||
rangeStart = start;
|
||||
rangeEnd = current;
|
||||
result = divideUpDictionaryRange(text, rangeStart, rangeEnd, foundBreaks);
|
||||
result = divideUpDictionaryRange(text, rangeStart, rangeEnd, foundBreaks, status);
|
||||
utext_setNativeIndex(text, current);
|
||||
|
||||
return result;
|
||||
|
@ -227,7 +229,9 @@ int32_t
|
|||
ThaiBreakEngine::divideUpDictionaryRange( UText *text,
|
||||
int32_t rangeStart,
|
||||
int32_t rangeEnd,
|
||||
UVector32 &foundBreaks ) const {
|
||||
UVector32 &foundBreaks,
|
||||
UErrorCode& status) const {
|
||||
if (U_FAILURE(status)) return 0;
|
||||
utext_setNativeIndex(text, rangeStart);
|
||||
utext_moveIndex32(text, THAI_MIN_WORD_SPAN);
|
||||
if (utext_getNativeIndex(text) >= rangeEnd) {
|
||||
|
@ -240,7 +244,6 @@ ThaiBreakEngine::divideUpDictionaryRange( UText *text,
|
|||
int32_t cpWordLength = 0; // Word Length in Code Points.
|
||||
int32_t cuWordLength = 0; // Word length in code units (UText native indexing)
|
||||
int32_t current;
|
||||
UErrorCode status = U_ZERO_ERROR;
|
||||
PossibleWord words[THAI_LOOKAHEAD];
|
||||
|
||||
utext_setNativeIndex(text, rangeStart);
|
||||
|
@ -465,7 +468,9 @@ int32_t
|
|||
LaoBreakEngine::divideUpDictionaryRange( UText *text,
|
||||
int32_t rangeStart,
|
||||
int32_t rangeEnd,
|
||||
UVector32 &foundBreaks ) const {
|
||||
UVector32 &foundBreaks,
|
||||
UErrorCode& status) const {
|
||||
if (U_FAILURE(status)) return 0;
|
||||
if ((rangeEnd - rangeStart) < LAO_MIN_WORD_SPAN) {
|
||||
return 0; // Not enough characters for two words
|
||||
}
|
||||
|
@ -474,11 +479,10 @@ LaoBreakEngine::divideUpDictionaryRange( UText *text,
|
|||
int32_t cpWordLength = 0;
|
||||
int32_t cuWordLength = 0;
|
||||
int32_t current;
|
||||
UErrorCode status = U_ZERO_ERROR;
|
||||
PossibleWord words[LAO_LOOKAHEAD];
|
||||
|
||||
|
||||
utext_setNativeIndex(text, rangeStart);
|
||||
|
||||
|
||||
while (U_SUCCESS(status) && (current = (int32_t)utext_getNativeIndex(text)) < rangeEnd) {
|
||||
cuWordLength = 0;
|
||||
cpWordLength = 0;
|
||||
|
@ -657,7 +661,9 @@ int32_t
|
|||
BurmeseBreakEngine::divideUpDictionaryRange( UText *text,
|
||||
int32_t rangeStart,
|
||||
int32_t rangeEnd,
|
||||
UVector32 &foundBreaks ) const {
|
||||
UVector32 &foundBreaks,
|
||||
UErrorCode& status ) const {
|
||||
if (U_FAILURE(status)) return 0;
|
||||
if ((rangeEnd - rangeStart) < BURMESE_MIN_WORD_SPAN) {
|
||||
return 0; // Not enough characters for two words
|
||||
}
|
||||
|
@ -666,11 +672,10 @@ BurmeseBreakEngine::divideUpDictionaryRange( UText *text,
|
|||
int32_t cpWordLength = 0;
|
||||
int32_t cuWordLength = 0;
|
||||
int32_t current;
|
||||
UErrorCode status = U_ZERO_ERROR;
|
||||
PossibleWord words[BURMESE_LOOKAHEAD];
|
||||
|
||||
|
||||
utext_setNativeIndex(text, rangeStart);
|
||||
|
||||
|
||||
while (U_SUCCESS(status) && (current = (int32_t)utext_getNativeIndex(text)) < rangeEnd) {
|
||||
cuWordLength = 0;
|
||||
cpWordLength = 0;
|
||||
|
@ -861,7 +866,9 @@ int32_t
|
|||
KhmerBreakEngine::divideUpDictionaryRange( UText *text,
|
||||
int32_t rangeStart,
|
||||
int32_t rangeEnd,
|
||||
UVector32 &foundBreaks ) const {
|
||||
UVector32 &foundBreaks,
|
||||
UErrorCode& status ) const {
|
||||
if (U_FAILURE(status)) return 0;
|
||||
if ((rangeEnd - rangeStart) < KHMER_MIN_WORD_SPAN) {
|
||||
return 0; // Not enough characters for two words
|
||||
}
|
||||
|
@ -870,7 +877,6 @@ KhmerBreakEngine::divideUpDictionaryRange( UText *text,
|
|||
int32_t cpWordLength = 0;
|
||||
int32_t cuWordLength = 0;
|
||||
int32_t current;
|
||||
UErrorCode status = U_ZERO_ERROR;
|
||||
PossibleWord words[KHMER_LOOKAHEAD];
|
||||
|
||||
utext_setNativeIndex(text, rangeStart);
|
||||
|
@ -1110,7 +1116,9 @@ int32_t
|
|||
CjkBreakEngine::divideUpDictionaryRange( UText *inText,
|
||||
int32_t rangeStart,
|
||||
int32_t rangeEnd,
|
||||
UVector32 &foundBreaks ) const {
|
||||
UVector32 &foundBreaks,
|
||||
UErrorCode& status) const {
|
||||
if (U_FAILURE(status)) return 0;
|
||||
if (rangeStart >= rangeEnd) {
|
||||
return 0;
|
||||
}
|
||||
|
@ -1122,9 +1130,6 @@ CjkBreakEngine::divideUpDictionaryRange( UText *inText,
|
|||
// If NULL then mapping is 1:1
|
||||
LocalPointer<UVector32> inputMap;
|
||||
|
||||
UErrorCode status = U_ZERO_ERROR;
|
||||
|
||||
|
||||
// if UText has the input string as one contiguous UTF-16 chunk
|
||||
if ((inText->providerProperties & utext_i32_flag(UTEXT_PROVIDER_STABLE_CHUNKS)) &&
|
||||
inText->chunkNativeStart <= rangeStart &&
|
||||
|
|
|
@ -68,17 +68,19 @@ class DictionaryBreakEngine : public LanguageBreakEngine {
|
|||
* <p>Find any breaks within a run in the supplied text.</p>
|
||||
*
|
||||
* @param text A UText representing the text. The iterator is left at
|
||||
* the end of the run of characters which the engine is capable of handling
|
||||
* the end of the run of characters which the engine is capable of handling
|
||||
* that starts from the first character in the range.
|
||||
* @param startPos The start of the run within the supplied text.
|
||||
* @param endPos The end of the run within the supplied text.
|
||||
* @param foundBreaks vector of int32_t to receive the break positions
|
||||
* @param status Information on any errors encountered.
|
||||
* @return The number of breaks found.
|
||||
*/
|
||||
virtual int32_t findBreaks( UText *text,
|
||||
int32_t startPos,
|
||||
int32_t endPos,
|
||||
UVector32 &foundBreaks ) const;
|
||||
UVector32 &foundBreaks,
|
||||
UErrorCode& status ) const;
|
||||
|
||||
protected:
|
||||
|
||||
|
@ -96,12 +98,14 @@ class DictionaryBreakEngine : public LanguageBreakEngine {
|
|||
* @param rangeStart The start of the range of dictionary characters
|
||||
* @param rangeEnd The end of the range of dictionary characters
|
||||
* @param foundBreaks Output of C array of int32_t break positions, or 0
|
||||
* @param status Information on any errors encountered.
|
||||
* @return The number of breaks found
|
||||
*/
|
||||
virtual int32_t divideUpDictionaryRange( UText *text,
|
||||
int32_t rangeStart,
|
||||
int32_t rangeEnd,
|
||||
UVector32 &foundBreaks ) const = 0;
|
||||
UVector32 &foundBreaks,
|
||||
UErrorCode& status) const = 0;
|
||||
|
||||
};
|
||||
|
||||
|
@ -153,12 +157,14 @@ class ThaiBreakEngine : public DictionaryBreakEngine {
|
|||
* @param rangeStart The start of the range of dictionary characters
|
||||
* @param rangeEnd The end of the range of dictionary characters
|
||||
* @param foundBreaks Output of C array of int32_t break positions, or 0
|
||||
* @param status Information on any errors encountered.
|
||||
* @return The number of breaks found
|
||||
*/
|
||||
virtual int32_t divideUpDictionaryRange( UText *text,
|
||||
int32_t rangeStart,
|
||||
int32_t rangeEnd,
|
||||
UVector32 &foundBreaks ) const;
|
||||
UVector32 &foundBreaks,
|
||||
UErrorCode& status) const;
|
||||
|
||||
};
|
||||
|
||||
|
@ -209,127 +215,133 @@ class LaoBreakEngine : public DictionaryBreakEngine {
|
|||
* @param rangeStart The start of the range of dictionary characters
|
||||
* @param rangeEnd The end of the range of dictionary characters
|
||||
* @param foundBreaks Output of C array of int32_t break positions, or 0
|
||||
* @param status Information on any errors encountered.
|
||||
* @return The number of breaks found
|
||||
*/
|
||||
virtual int32_t divideUpDictionaryRange( UText *text,
|
||||
int32_t rangeStart,
|
||||
int32_t rangeEnd,
|
||||
UVector32 &foundBreaks ) const;
|
||||
UVector32 &foundBreaks,
|
||||
UErrorCode& status) const;
|
||||
|
||||
};
|
||||
|
||||
/*******************************************************************
|
||||
* BurmeseBreakEngine
|
||||
*/
|
||||
|
||||
/**
|
||||
* <p>BurmeseBreakEngine is a kind of DictionaryBreakEngine that uses a
|
||||
* DictionaryMatcher and heuristics to determine Burmese-specific breaks.</p>
|
||||
*
|
||||
* <p>After it is constructed a BurmeseBreakEngine may be shared between
|
||||
* threads without synchronization.</p>
|
||||
*/
|
||||
class BurmeseBreakEngine : public DictionaryBreakEngine {
|
||||
private:
|
||||
/**
|
||||
* The set of characters handled by this engine
|
||||
* @internal
|
||||
*/
|
||||
|
||||
UnicodeSet fBurmeseWordSet;
|
||||
UnicodeSet fEndWordSet;
|
||||
UnicodeSet fBeginWordSet;
|
||||
UnicodeSet fMarkSet;
|
||||
DictionaryMatcher *fDictionary;
|
||||
|
||||
public:
|
||||
|
||||
/**
|
||||
* <p>Default constructor.</p>
|
||||
*
|
||||
* @param adoptDictionary A DictionaryMatcher to adopt. Deleted when the
|
||||
* engine is deleted.
|
||||
*/
|
||||
BurmeseBreakEngine(DictionaryMatcher *adoptDictionary, UErrorCode &status);
|
||||
|
||||
/**
|
||||
* <p>Virtual destructor.</p>
|
||||
*/
|
||||
virtual ~BurmeseBreakEngine();
|
||||
|
||||
protected:
|
||||
/**
|
||||
* <p>Divide up a range of known dictionary characters.</p>
|
||||
*
|
||||
* @param text A UText representing the text
|
||||
* @param rangeStart The start of the range of dictionary characters
|
||||
* @param rangeEnd The end of the range of dictionary characters
|
||||
* @param foundBreaks Output of C array of int32_t break positions, or 0
|
||||
* @param status Information on any errors encountered.
|
||||
* @return The number of breaks found
|
||||
*/
|
||||
virtual int32_t divideUpDictionaryRange( UText *text,
|
||||
int32_t rangeStart,
|
||||
int32_t rangeEnd,
|
||||
UVector32 &foundBreaks,
|
||||
UErrorCode& status) const;
|
||||
|
||||
};
|
||||
|
||||
/*******************************************************************
|
||||
* KhmerBreakEngine
|
||||
*/
|
||||
|
||||
/**
|
||||
* <p>KhmerBreakEngine is a kind of DictionaryBreakEngine that uses a
|
||||
* DictionaryMatcher and heuristics to determine Khmer-specific breaks.</p>
|
||||
*
|
||||
* <p>After it is constructed a KhmerBreakEngine may be shared between
|
||||
* threads without synchronization.</p>
|
||||
*/
|
||||
class KhmerBreakEngine : public DictionaryBreakEngine {
|
||||
private:
|
||||
/**
|
||||
* The set of characters handled by this engine
|
||||
* @internal
|
||||
*/
|
||||
|
||||
UnicodeSet fKhmerWordSet;
|
||||
UnicodeSet fEndWordSet;
|
||||
UnicodeSet fBeginWordSet;
|
||||
UnicodeSet fMarkSet;
|
||||
DictionaryMatcher *fDictionary;
|
||||
|
||||
public:
|
||||
|
||||
/**
|
||||
* <p>Default constructor.</p>
|
||||
*
|
||||
* @param adoptDictionary A DictionaryMatcher to adopt. Deleted when the
|
||||
* engine is deleted.
|
||||
*/
|
||||
KhmerBreakEngine(DictionaryMatcher *adoptDictionary, UErrorCode &status);
|
||||
|
||||
/**
|
||||
* <p>Virtual destructor.</p>
|
||||
*/
|
||||
virtual ~KhmerBreakEngine();
|
||||
|
||||
protected:
|
||||
/**
|
||||
* <p>Divide up a range of known dictionary characters.</p>
|
||||
*
|
||||
* @param text A UText representing the text
|
||||
* @param rangeStart The start of the range of dictionary characters
|
||||
* @param rangeEnd The end of the range of dictionary characters
|
||||
* @param foundBreaks Output of C array of int32_t break positions, or 0
|
||||
* @param status Information on any errors encountered.
|
||||
* @return The number of breaks found
|
||||
*/
|
||||
virtual int32_t divideUpDictionaryRange( UText *text,
|
||||
int32_t rangeStart,
|
||||
int32_t rangeEnd,
|
||||
UVector32 &foundBreaks,
|
||||
UErrorCode& status) const;
|
||||
|
||||
};
|
||||
|
||||
/*******************************************************************
|
||||
* BurmeseBreakEngine
|
||||
*/
|
||||
|
||||
/**
|
||||
* <p>BurmeseBreakEngine is a kind of DictionaryBreakEngine that uses a
|
||||
* DictionaryMatcher and heuristics to determine Burmese-specific breaks.</p>
|
||||
*
|
||||
* <p>After it is constructed a BurmeseBreakEngine may be shared between
|
||||
* threads without synchronization.</p>
|
||||
*/
|
||||
class BurmeseBreakEngine : public DictionaryBreakEngine {
|
||||
private:
|
||||
/**
|
||||
* The set of characters handled by this engine
|
||||
* @internal
|
||||
*/
|
||||
|
||||
UnicodeSet fBurmeseWordSet;
|
||||
UnicodeSet fEndWordSet;
|
||||
UnicodeSet fBeginWordSet;
|
||||
UnicodeSet fMarkSet;
|
||||
DictionaryMatcher *fDictionary;
|
||||
|
||||
public:
|
||||
|
||||
/**
|
||||
* <p>Default constructor.</p>
|
||||
*
|
||||
* @param adoptDictionary A DictionaryMatcher to adopt. Deleted when the
|
||||
* engine is deleted.
|
||||
*/
|
||||
BurmeseBreakEngine(DictionaryMatcher *adoptDictionary, UErrorCode &status);
|
||||
|
||||
/**
|
||||
* <p>Virtual destructor.</p>
|
||||
*/
|
||||
virtual ~BurmeseBreakEngine();
|
||||
|
||||
protected:
|
||||
/**
|
||||
* <p>Divide up a range of known dictionary characters.</p>
|
||||
*
|
||||
* @param text A UText representing the text
|
||||
* @param rangeStart The start of the range of dictionary characters
|
||||
* @param rangeEnd The end of the range of dictionary characters
|
||||
* @param foundBreaks Output of C array of int32_t break positions, or 0
|
||||
* @return The number of breaks found
|
||||
*/
|
||||
virtual int32_t divideUpDictionaryRange( UText *text,
|
||||
int32_t rangeStart,
|
||||
int32_t rangeEnd,
|
||||
UVector32 &foundBreaks ) const;
|
||||
|
||||
};
|
||||
|
||||
/*******************************************************************
|
||||
* KhmerBreakEngine
|
||||
*/
|
||||
|
||||
/**
|
||||
* <p>KhmerBreakEngine is a kind of DictionaryBreakEngine that uses a
|
||||
* DictionaryMatcher and heuristics to determine Khmer-specific breaks.</p>
|
||||
*
|
||||
* <p>After it is constructed a KhmerBreakEngine may be shared between
|
||||
* threads without synchronization.</p>
|
||||
*/
|
||||
class KhmerBreakEngine : public DictionaryBreakEngine {
|
||||
private:
|
||||
/**
|
||||
* The set of characters handled by this engine
|
||||
* @internal
|
||||
*/
|
||||
|
||||
UnicodeSet fKhmerWordSet;
|
||||
UnicodeSet fEndWordSet;
|
||||
UnicodeSet fBeginWordSet;
|
||||
UnicodeSet fMarkSet;
|
||||
DictionaryMatcher *fDictionary;
|
||||
|
||||
public:
|
||||
|
||||
/**
|
||||
* <p>Default constructor.</p>
|
||||
*
|
||||
* @param adoptDictionary A DictionaryMatcher to adopt. Deleted when the
|
||||
* engine is deleted.
|
||||
*/
|
||||
KhmerBreakEngine(DictionaryMatcher *adoptDictionary, UErrorCode &status);
|
||||
|
||||
/**
|
||||
* <p>Virtual destructor.</p>
|
||||
*/
|
||||
virtual ~KhmerBreakEngine();
|
||||
|
||||
protected:
|
||||
/**
|
||||
* <p>Divide up a range of known dictionary characters.</p>
|
||||
*
|
||||
* @param text A UText representing the text
|
||||
* @param rangeStart The start of the range of dictionary characters
|
||||
* @param rangeEnd The end of the range of dictionary characters
|
||||
* @param foundBreaks Output of C array of int32_t break positions, or 0
|
||||
* @return The number of breaks found
|
||||
*/
|
||||
virtual int32_t divideUpDictionaryRange( UText *text,
|
||||
int32_t rangeStart,
|
||||
int32_t rangeEnd,
|
||||
UVector32 &foundBreaks ) const;
|
||||
|
||||
};
|
||||
|
||||
#if !UCONFIG_NO_NORMALIZATION
|
||||
|
||||
/*******************************************************************
|
||||
|
@ -385,12 +397,14 @@ class CjkBreakEngine : public DictionaryBreakEngine {
|
|||
* @param rangeStart The start of the range of dictionary characters
|
||||
* @param rangeEnd The end of the range of dictionary characters
|
||||
* @param foundBreaks Output of C array of int32_t break positions, or 0
|
||||
* @param status Information on any errors encountered.
|
||||
* @return The number of breaks found
|
||||
*/
|
||||
virtual int32_t divideUpDictionaryRange( UText *text,
|
||||
int32_t rangeStart,
|
||||
int32_t rangeEnd,
|
||||
UVector32 &foundBreaks ) const;
|
||||
UVector32 &foundBreaks,
|
||||
UErrorCode& status) const;
|
||||
|
||||
};
|
||||
|
||||
|
|
|
@ -162,10 +162,16 @@ ConstArray2D::~ConstArray2D()
|
|||
class Array1D : public ReadArray1D {
|
||||
public:
|
||||
Array1D() : memory_(nullptr), data_(nullptr), d1_(0) {}
|
||||
Array1D(int32_t d1)
|
||||
Array1D(int32_t d1, UErrorCode &status)
|
||||
: memory_(uprv_malloc(d1 * sizeof(float))),
|
||||
data_((float*)memory_), d1_(d1) {
|
||||
clear();
|
||||
if (U_SUCCESS(status)) {
|
||||
if (memory_ == nullptr) {
|
||||
status = U_MEMORY_ALLOCATION_ERROR;
|
||||
return;
|
||||
}
|
||||
clear();
|
||||
}
|
||||
}
|
||||
|
||||
virtual ~Array1D();
|
||||
|
@ -278,10 +284,16 @@ Array1D::~Array1D()
|
|||
class Array2D : public ReadArray2D {
|
||||
public:
|
||||
Array2D() : memory_(nullptr), data_(nullptr), d1_(0), d2_(0) {}
|
||||
Array2D(int32_t d1, int32_t d2)
|
||||
Array2D(int32_t d1, int32_t d2, UErrorCode &status)
|
||||
: memory_(uprv_malloc(d1 * d2 * sizeof(float))),
|
||||
data_((float*)memory_), d1_(d1), d2_(d2) {
|
||||
clear();
|
||||
if (U_SUCCESS(status)) {
|
||||
if (memory_ == nullptr) {
|
||||
status = U_MEMORY_ALLOCATION_ERROR;
|
||||
return;
|
||||
}
|
||||
clear();
|
||||
}
|
||||
}
|
||||
virtual ~Array2D();
|
||||
|
||||
|
@ -366,8 +378,10 @@ LSTMData::LSTMData(UResourceBundle* rb, UErrorCode &status)
|
|||
int32_t embedding_size = ures_getInt(embeddings_res.getAlias(), &status);
|
||||
LocalUResourceBundlePointer hunits_res(
|
||||
ures_getByKey(rb, "hunits", nullptr, &status));
|
||||
if (U_FAILURE(status)) return;
|
||||
int32_t hunits = ures_getInt(hunits_res.getAlias(), &status);
|
||||
const UChar* type = ures_getStringByKey(rb, "type", nullptr, &status);
|
||||
if (U_FAILURE(status)) return;
|
||||
if (u_strCompare(type, -1, u"codepoints", -1, false) == 0) {
|
||||
fType = CODE_POINTS;
|
||||
} else if (u_strCompare(type, -1, u"graphclust", -1, false) == 0) {
|
||||
|
@ -375,15 +389,15 @@ LSTMData::LSTMData(UResourceBundle* rb, UErrorCode &status)
|
|||
}
|
||||
fName = ures_getStringByKey(rb, "model", nullptr, &status);
|
||||
fDataRes = ures_getByKey(rb, "data", nullptr, &status);
|
||||
if (U_FAILURE(status)) return;
|
||||
int32_t data_len = 0;
|
||||
const int32_t* data = ures_getIntVector(fDataRes, &data_len, &status);
|
||||
if (U_FAILURE(status)) return;
|
||||
LocalUResourceBundlePointer fDictRes(
|
||||
ures_getByKey(rb, "dict", nullptr, &status));
|
||||
int32_t num_index = ures_getSize(fDictRes.getAlias());
|
||||
fDict = uhash_open(uhash_hashUChars, uhash_compareUChars, nullptr, &status);
|
||||
if (U_FAILURE(status)) {
|
||||
return;
|
||||
}
|
||||
if (U_FAILURE(status)) return;
|
||||
|
||||
ures_resetIterator(fDictRes.getAlias());
|
||||
int32_t idx = 0;
|
||||
|
@ -391,10 +405,9 @@ LSTMData::LSTMData(UResourceBundle* rb, UErrorCode &status)
|
|||
while(ures_hasNext(fDictRes.getAlias())) {
|
||||
const char *tempKey = nullptr;
|
||||
const UChar* str = ures_getNextString(fDictRes.getAlias(), nullptr, &tempKey, &status);
|
||||
if (U_FAILURE(status)) return;
|
||||
uhash_putiAllowZero(fDict, (void*)str, idx++, &status);
|
||||
if (U_FAILURE(status)) {
|
||||
return;
|
||||
}
|
||||
if (U_FAILURE(status)) return;
|
||||
#ifdef LSTM_VECTORIZER_DEBUG
|
||||
printf("Assign [");
|
||||
while (*str != 0x0000) {
|
||||
|
@ -495,6 +508,7 @@ void CodePointsVectorizer::vectorize(
|
|||
{
|
||||
if (offsets.ensureCapacity(endPos - startPos, status) &&
|
||||
indices.ensureCapacity(endPos - startPos, status)) {
|
||||
if (U_FAILURE(status)) return;
|
||||
utext_setNativeIndex(text, startPos);
|
||||
int32_t current;
|
||||
UChar str[2] = {0, 0};
|
||||
|
@ -533,21 +547,16 @@ void GraphemeClusterVectorizer::vectorize(
|
|||
UText *text, int32_t startPos, int32_t endPos,
|
||||
UVector32 &offsets, UVector32 &indices, UErrorCode &status) const
|
||||
{
|
||||
if (U_FAILURE(status)) {
|
||||
return;
|
||||
}
|
||||
if (U_FAILURE(status)) return;
|
||||
if (!offsets.ensureCapacity(endPos - startPos, status) ||
|
||||
!indices.ensureCapacity(endPos - startPos, status)) {
|
||||
return;
|
||||
}
|
||||
if (U_FAILURE(status)) return;
|
||||
LocalPointer<BreakIterator> graphemeIter(BreakIterator::createCharacterInstance(Locale(), status));
|
||||
if (U_FAILURE(status)) {
|
||||
return;
|
||||
}
|
||||
if (U_FAILURE(status)) return;
|
||||
graphemeIter->setText(text, status);
|
||||
if (U_FAILURE(status)) {
|
||||
return;
|
||||
}
|
||||
if (U_FAILURE(status)) return;
|
||||
|
||||
if (startPos != 0) {
|
||||
graphemeIter->preceding(startPos);
|
||||
|
@ -561,11 +570,10 @@ void GraphemeClusterVectorizer::vectorize(
|
|||
}
|
||||
if (current > startPos) {
|
||||
utext_extract(text, last, current, str, MAX_GRAPHEME_CLSTER_LENTH, &status);
|
||||
if (U_FAILURE(status)) {
|
||||
break;
|
||||
}
|
||||
if (U_FAILURE(status)) return;
|
||||
offsets.addElement(last, status);
|
||||
indices.addElement(stringToIndex(str), status);
|
||||
if (U_FAILURE(status)) return;
|
||||
}
|
||||
last = current;
|
||||
}
|
||||
|
@ -583,13 +591,19 @@ void GraphemeClusterVectorizer::vectorize(
|
|||
// https://en.wikipedia.org/wiki/Long_short-term_memory#LSTM_with_a_forget_gate
|
||||
void compute(
|
||||
const ReadArray2D& W, const ReadArray2D& U, const ReadArray1D& b,
|
||||
const ReadArray1D& x, Array1D& h, Array1D& c)
|
||||
const ReadArray1D& x, Array1D& h, Array1D& c,
|
||||
UErrorCode &status)
|
||||
{
|
||||
if (U_FAILURE(status)) return;
|
||||
// ifco = x * W + h * U + b
|
||||
Array1D ifco(b.d1());
|
||||
ifco.dotProduct(x, W)
|
||||
.add(Array1D(b.d1()).dotProduct(h, U))
|
||||
.add(b);
|
||||
Array1D ifco(b.d1(), status);
|
||||
{
|
||||
Array1D hU(b.d1(), status);
|
||||
if (U_FAILURE(status)) return;
|
||||
ifco.dotProduct(x, W)
|
||||
.add(hU.dotProduct(h, U))
|
||||
.add(b);
|
||||
} // deallocate hU
|
||||
|
||||
int32_t hunits = b.d1() / 4;
|
||||
ifco.slice(0*hunits, hunits).sigmoid(); // i: sigmoid
|
||||
|
@ -597,10 +611,14 @@ void compute(
|
|||
ifco.slice(2*hunits, hunits).tanh(); // c_: tanh
|
||||
ifco.slice(3*hunits, hunits).sigmoid(); // o: sigmoid
|
||||
|
||||
c.hadamardProduct(ifco.slice(hunits, hunits))
|
||||
.add(Array1D(c.d1())
|
||||
.assign(ifco.slice(0, hunits))
|
||||
.hadamardProduct(ifco.slice(2*hunits, hunits)));
|
||||
{
|
||||
Array1D ic(c.d1(), status);
|
||||
if (U_FAILURE(status)) return;
|
||||
c.hadamardProduct(ifco.slice(hunits, hunits))
|
||||
.add(ic
|
||||
.assign(ifco.slice(0, hunits))
|
||||
.hadamardProduct(ifco.slice(2*hunits, hunits)));
|
||||
}
|
||||
|
||||
h.assign(c)
|
||||
.tanh()
|
||||
|
@ -614,10 +632,12 @@ static const int32_t MIN_WORD = 2;
|
|||
static const int32_t MIN_WORD_SPAN = MIN_WORD * 2;
|
||||
|
||||
int32_t
|
||||
LSTMBreakEngine::divideUpDictionaryRange(UText *text,
|
||||
int32_t startPos,
|
||||
int32_t endPos,
|
||||
UVector32 &foundBreaks) const {
|
||||
LSTMBreakEngine::divideUpDictionaryRange( UText *text,
|
||||
int32_t startPos,
|
||||
int32_t endPos,
|
||||
UVector32 &foundBreaks,
|
||||
UErrorCode& status) const {
|
||||
if (U_FAILURE(status)) return 0;
|
||||
int32_t beginFoundBreakSize = foundBreaks.size();
|
||||
utext_setNativeIndex(text, startPos);
|
||||
utext_moveIndex32(text, MIN_WORD_SPAN);
|
||||
|
@ -625,11 +645,12 @@ LSTMBreakEngine::divideUpDictionaryRange(UText *text,
|
|||
return 0; // Not enough characters for two words
|
||||
}
|
||||
utext_setNativeIndex(text, startPos);
|
||||
UErrorCode status = U_ZERO_ERROR;
|
||||
|
||||
UVector32 offsets(status);
|
||||
UVector32 indices(status);
|
||||
if (U_FAILURE(status)) return 0;
|
||||
fVectorizer->vectorize(text, startPos, endPos, offsets, indices, status);
|
||||
if (U_FAILURE(status)) return 0;
|
||||
int32_t* offsetsBuf = offsets.getBuffer();
|
||||
int32_t* indicesBuf = indices.getBuffer();
|
||||
|
||||
|
@ -640,12 +661,13 @@ LSTMBreakEngine::divideUpDictionaryRange(UText *text,
|
|||
// Python or ICU4X implementation. We first perform the Backward LSTM
|
||||
// and then merge the iteration of the forward LSTM and the output layer
|
||||
// together because we only need to remember the h[t-1] for Forward LSTM.
|
||||
Array1D c(hunits);
|
||||
Array1D c(hunits, status);
|
||||
|
||||
// TODO: limit size of hBackward. If input_seq_len is too big, we could
|
||||
// run out of memory.
|
||||
// Backward LSTM
|
||||
Array2D hBackward(input_seq_len, hunits);
|
||||
Array2D hBackward(input_seq_len, hunits, status);
|
||||
if (U_FAILURE(status)) return 0;
|
||||
for (int32_t i = input_seq_len - 1; i >= 0; i--) {
|
||||
Array1D hRow = hBackward.row(i);
|
||||
if (i != input_seq_len - 1) {
|
||||
|
@ -660,13 +682,15 @@ LSTMBreakEngine::divideUpDictionaryRange(UText *text,
|
|||
#endif // LSTM_DEBUG
|
||||
compute(fData->fBackwardW, fData->fBackwardU, fData->fBackwardB,
|
||||
fData->fEmbedding.row(indicesBuf[i]),
|
||||
hRow, c);
|
||||
hRow, c, status);
|
||||
if (U_FAILURE(status)) return 0;
|
||||
}
|
||||
|
||||
Array1D logp(4);
|
||||
Array1D logp(4, status);
|
||||
|
||||
// Allocate fbRow and slice the internal array in two.
|
||||
Array1D fbRow(2 * hunits);
|
||||
Array1D fbRow(2 * hunits, status);
|
||||
if (U_FAILURE(status)) return 0;
|
||||
Array1D forwardRow = fbRow.slice(0, hunits); // point to first half of data in fbRow.
|
||||
Array1D backwardRow = fbRow.slice(hunits, hunits); // point to second half of data in fbRow.
|
||||
|
||||
|
@ -683,7 +707,8 @@ LSTMBreakEngine::divideUpDictionaryRange(UText *text,
|
|||
// of fbRow.
|
||||
compute(fData->fForwardW, fData->fForwardU, fData->fForwardB,
|
||||
fData->fEmbedding.row(indicesBuf[i]),
|
||||
forwardRow, c);
|
||||
forwardRow, c, status);
|
||||
if (U_FAILURE(status)) return 0;
|
||||
|
||||
// assign the data from hBackward.row(i) to second half of fbRow.
|
||||
backwardRow.assign(hBackward.row(i));
|
||||
|
@ -702,6 +727,7 @@ LSTMBreakEngine::divideUpDictionaryRange(UText *text,
|
|||
if (current == BEGIN || current == SINGLE) {
|
||||
if (i != 0) {
|
||||
foundBreaks.addElement(offsetsBuf[i], status);
|
||||
if (U_FAILURE(status)) return 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -759,14 +785,13 @@ U_CAPI const LSTMData* U_EXPORT2 CreateLSTMDataForScript(UScriptCode script, UEr
|
|||
return nullptr;
|
||||
}
|
||||
UnicodeString name = defaultLSTM(script, status);
|
||||
if (U_FAILURE(status)) return nullptr;
|
||||
CharString namebuf;
|
||||
namebuf.appendInvariantChars(name, status).truncate(namebuf.lastIndexOf('.'));
|
||||
|
||||
LocalUResourceBundlePointer rb(
|
||||
ures_openDirect(U_ICUDATA_BRKITR, namebuf.data(), &status));
|
||||
if (U_FAILURE(status)) {
|
||||
return nullptr;
|
||||
}
|
||||
if (U_FAILURE(status)) return nullptr;
|
||||
|
||||
return CreateLSTMData(rb.getAlias(), status);
|
||||
}
|
||||
|
|
|
@ -55,12 +55,14 @@ protected:
|
|||
* @param rangeStart The start of the range of dictionary characters
|
||||
* @param rangeEnd The end of the range of dictionary characters
|
||||
* @param foundBreaks Output of C array of int32_t break positions, or 0
|
||||
* @param status Information on any errors encountered.
|
||||
* @return The number of breaks found
|
||||
*/
|
||||
virtual int32_t divideUpDictionaryRange(UText *text,
|
||||
int32_t rangeStart,
|
||||
int32_t rangeEnd,
|
||||
UVector32 &foundBreaks ) const;
|
||||
UVector32 &foundBreaks,
|
||||
UErrorCode& status) const;
|
||||
private:
|
||||
const LSTMData* fData;
|
||||
const Vectorizer* fVectorizer;
|
||||
|
|
|
@ -163,7 +163,7 @@ void RuleBasedBreakIterator::DictionaryCache::populateDictionary(int32_t startPo
|
|||
// Ask the language object if there are any breaks. It will add them to the cache and
|
||||
// leave the text pointer on the other side of its range, ready to search for the next one.
|
||||
if (lbe != NULL) {
|
||||
foundBreakCount += lbe->findBreaks(text, rangeStart, rangeEnd, fBreaks);
|
||||
foundBreakCount += lbe->findBreaks(text, rangeStart, rangeEnd, fBreaks, status);
|
||||
}
|
||||
|
||||
// Reload the loop variables for the next go-round
|
||||
|
|
|
@ -151,7 +151,11 @@ void LSTMBETest::runTestFromFile(const char* filename) {
|
|||
dataerrln("%s:%d Error %s Could not allocate UVextor32", __FILE__, __LINE__, u_errorName(status));
|
||||
return;
|
||||
}
|
||||
engine->findBreaks(&ut, 0, value.length(), actual);
|
||||
engine->findBreaks(&ut, 0, value.length(), actual, status);
|
||||
if (U_FAILURE(status)) {
|
||||
dataerrln("%s:%d Error %s findBreaks failed", __FILE__, __LINE__, u_errorName(status));
|
||||
return;
|
||||
}
|
||||
utext_close(&ut);
|
||||
for (int32_t i = 0; i < actual.size(); i++) {
|
||||
ss << actual.elementAti(i) << ", ";
|
||||
|
|
Loading…
Add table
Reference in a new issue