ICU-22100 Modify ML model to improve Japanese phrase breaking performance

This commit is contained in:
allenwtsu 2023-01-31 18:17:02 +08:00 committed by Frank Yung-Fong Tang
parent 5560ee8870
commit 3f05361b41
4 changed files with 1044 additions and 980 deletions

View file

@ -18,11 +18,12 @@
U_NAMESPACE_BEGIN
enum class ModelIndex { kUWStart = 0, kBWStart = 6, kTWStart = 9 };
MlBreakEngine::MlBreakEngine(const UnicodeSet &digitOrOpenPunctuationOrAlphabetSet,
const UnicodeSet &closePunctuationSet, UErrorCode &status)
const UnicodeSet &closePunctuationSet, UErrorCode &status)
: fDigitOrOpenPunctuationOrAlphabetSet(digitOrOpenPunctuationOrAlphabetSet),
fClosePunctuationSet(closePunctuationSet),
fModel(status),
fNegativeSum(0) {
if (U_FAILURE(status)) {
return;
@ -32,14 +33,10 @@ MlBreakEngine::MlBreakEngine(const UnicodeSet &digitOrOpenPunctuationOrAlphabetS
MlBreakEngine::~MlBreakEngine() {}
namespace {
const char16_t INVALID = u'|';
}
int32_t MlBreakEngine::divideUpRange(UText *inText, int32_t rangeStart, int32_t rangeEnd,
UVector32 &foundBreaks, const UnicodeString &inString,
const LocalPointer<UVector32> &inputMap,
UErrorCode &status) const {
UVector32 &foundBreaks, const UnicodeString &inString,
const LocalPointer<UVector32> &inputMap,
UErrorCode &status) const {
if (U_FAILURE(status)) {
return 0;
}
@ -53,30 +50,35 @@ int32_t MlBreakEngine::divideUpRange(UText *inText, int32_t rangeStart, int32_t
return 0;
}
int32_t numBreaks = 0;
UnicodeString index;
// The ML model groups six char to evaluate if the 4th char is a breakpoint.
// Like a sliding window, the elementList removes the first char and appends the new char from
// inString in each iteration so that its size always remains at six.
UChar32 elementList[6];
int32_t codeUts = initElementList(inString, elementList, status);
int32_t length = inString.countChar32();
int32_t codePointLength = inString.countChar32();
// The ML algorithm groups six char and evaluates whether the 4th char is a breakpoint.
// In each iteration, it evaluates the 4th char and then moves forward one char like a sliding
// window. Initially, the first six values in the indexList are [-1, -1, 0, 1, 2, 3]. After
// moving forward, finally the last six values in the indexList are
// [length-4, length-3, length-2, length-1, -1, -1]. The "+4" here means four extra "-1".
int32_t indexSize = codePointLength + 4;
int32_t *indexList = (int32_t *)uprv_malloc(indexSize * sizeof(int32_t));
if (indexList == nullptr) {
status = U_MEMORY_ALLOCATION_ERROR;
return 0;
}
int32_t numCodeUnits = initIndexList(inString, indexList, status);
// Add a break for the start.
boundary.addElement(0, status);
numBreaks++;
if (U_FAILURE(status)) return 0;
for (int32_t i = 1; i < length && U_SUCCESS(status); i++) {
evaluateBreakpoint(elementList, i, numBreaks, boundary, status);
if (i + 1 >= inString.countChar32()) break;
// Remove the first element and append a new element
uprv_memmove(elementList, elementList + 1, 5 * sizeof(UChar32));
elementList[5] = inString.countChar32(0, codeUts) < length ? inString.char32At(codeUts) : INVALID;
if (elementList[5] != INVALID) {
codeUts += U16_LENGTH(elementList[5]);
for (int32_t idx = 0; idx + 1 < codePointLength && U_SUCCESS(status); idx++) {
numBreaks =
evaluateBreakpoint(inString, indexList, idx, numCodeUnits, numBreaks, boundary, status);
if (idx + 4 < codePointLength) {
indexList[idx + 6] = numCodeUnits;
numCodeUnits += U16_LENGTH(inString.char32At(indexList[idx + 6]));
}
}
uprv_free(indexList);
if (U_FAILURE(status)) return 0;
// Add a break for the end if there is not one there already.
@ -128,119 +130,112 @@ int32_t MlBreakEngine::divideUpRange(UText *inText, int32_t rangeStart, int32_t
return correctedNumBreaks;
}
void MlBreakEngine::evaluateBreakpoint(UChar32* elementList, int32_t index, int32_t &numBreaks,
UVector32 &boundary, UErrorCode &status) const {
int32_t MlBreakEngine::evaluateBreakpoint(const UnicodeString &inString, int32_t *indexList,
int32_t startIdx, int32_t numCodeUnits, int32_t numBreaks,
UVector32 &boundary, UErrorCode &status) const {
if (U_FAILURE(status)) {
return;
return numBreaks;
}
UnicodeString feature;
int32_t start = 0, end = 0;
int32_t score = fNegativeSum;
if (elementList[0] != INVALID) {
// When the key doesn't exist, Hashtable.geti(key) returns 0 and 2 * 0 = 0.
// So, we can skip to check whether fModel includes key featureList[j] or not.
score += (2 * fModel.geti(feature.setTo(u"UW1:", 4).append(elementList[0])));
for (int i = 0; i < 6; i++) {
// UW1 ~ UW6
start = startIdx + i;
if (indexList[start] != -1) {
end = (indexList[start + 1] != -1) ? indexList[start + 1] : numCodeUnits;
score += fModel[static_cast<int32_t>(ModelIndex::kUWStart) + i].geti(
inString.tempSubString(indexList[start], end - indexList[start]));
}
}
if (elementList[1] != INVALID) {
score += (2 * fModel.geti(feature.setTo(u"UW2:", 4).append(elementList[1])));
for (int i = 0; i < 3; i++) {
// BW1 ~ BW3
start = startIdx + i + 1;
if (indexList[start] != -1 && indexList[start + 1] != -1) {
end = (indexList[start + 2] != -1) ? indexList[start + 2] : numCodeUnits;
score += fModel[static_cast<int32_t>(ModelIndex::kBWStart) + i].geti(
inString.tempSubString(indexList[start], end - indexList[start]));
}
}
if (elementList[2] != INVALID) {
score += (2 * fModel.geti(feature.setTo(u"UW3:", 4).append(elementList[2])));
}
if (elementList[3] != INVALID) {
score += (2 * fModel.geti(feature.setTo(u"UW4:", 4).append(elementList[3])));
}
if (elementList[4] != INVALID) {
score += (2 * fModel.geti(feature.setTo(u"UW5:", 4).append(elementList[4])));
}
if (elementList[5] != INVALID) {
score += (2 * fModel.geti(feature.setTo(u"UW6:", 4).append(elementList[5])));
}
if (elementList[1] != INVALID && elementList[2] != INVALID) {
score += (2 * fModel.geti(
feature.setTo(u"BW1:", 4).append(elementList[1]).append(elementList[2])));
}
if (elementList[2] != INVALID && elementList[3] != INVALID) {
score += (2 * fModel.geti(
feature.setTo(u"BW2:", 4).append(elementList[2]).append(elementList[3])));
}
if (elementList[3] != INVALID && elementList[4] != INVALID) {
score += (2 * fModel.geti(
feature.setTo(u"BW3:", 4).append(elementList[3]).append(elementList[4])));
}
if (elementList[0] != INVALID && elementList[1] != INVALID && elementList[2] != INVALID) {
score += (2 * fModel.geti(feature.setTo(u"TW1:", 4)
.append(elementList[0])
.append(elementList[1])
.append(elementList[2])));
}
if (elementList[1] != INVALID && elementList[2] != INVALID && elementList[3] != INVALID) {
score += (2 * fModel.geti(feature.setTo(u"TW2:", 4)
.append(elementList[1])
.append(elementList[2])
.append(elementList[3])));
}
if (elementList[2] != INVALID && elementList[3] != INVALID && elementList[4] != INVALID) {
score += (2 * fModel.geti(feature.setTo(u"TW3:", 4)
.append(elementList[2])
.append(elementList[3])
.append(elementList[4])));
}
if (elementList[3] != INVALID && elementList[4] != INVALID && elementList[5] != INVALID) {
score += (2 * fModel.geti(feature.setTo(u"TW4:", 4)
.append(elementList[3])
.append(elementList[4])
.append(elementList[5])));
for (int i = 0; i < 4; i++) {
// TW1 ~ TW4
start = startIdx + i;
if (indexList[start] != -1 && indexList[start + 1] != -1 && indexList[start + 2] != -1) {
end = (indexList[start + 3] != -1) ? indexList[start + 3] : numCodeUnits;
score += fModel[static_cast<int32_t>(ModelIndex::kTWStart) + i].geti(
inString.tempSubString(indexList[start], end - indexList[start]));
}
}
if (score > 0) {
boundary.addElement(index, status);
boundary.addElement(startIdx + 1, status);
numBreaks++;
}
return numBreaks;
}
int32_t MlBreakEngine::initElementList(const UnicodeString &inString, UChar32* elementList,
UErrorCode &status) const {
int32_t MlBreakEngine::initIndexList(const UnicodeString &inString, int32_t *indexList,
UErrorCode &status) const {
if (U_FAILURE(status)) {
return 0;
}
int32_t index = 0;
int32_t length = inString.countChar32();
UChar32 w1, w2, w3, w4, w5, w6;
w1 = w2 = w3 = w4 = w5 = w6 = INVALID;
// Set all (lenght+4) items inside indexLength to -1 presuming -1 is 4 bytes of 0xff.
uprv_memset(indexList, 0xff, (length + 4) * sizeof(int32_t));
if (length > 0) {
w3 = inString.char32At(0);
index += U16_LENGTH(w3);
indexList[2] = 0;
index = U16_LENGTH(inString.char32At(0));
if (length > 1) {
w4 = inString.char32At(index);
index += U16_LENGTH(w4);
indexList[3] = index;
index += U16_LENGTH(inString.char32At(index));
if (length > 2) {
w5 = inString.char32At(index);
index += U16_LENGTH(w5);
indexList[4] = index;
index += U16_LENGTH(inString.char32At(index));
if (length > 3) {
w6 = inString.char32At(index);
index += U16_LENGTH(w6);
indexList[5] = index;
index += U16_LENGTH(inString.char32At(index));
}
}
}
}
elementList[0] = w1;
elementList[1] = w2;
elementList[2] = w3;
elementList[3] = w4;
elementList[4] = w5;
elementList[5] = w6;
return index;
}
void MlBreakEngine::loadMLModel(UErrorCode &error) {
// BudouX's model consists of pairs of the feature and its score.
// As integrating it into jaml.txt, modelKeys denotes the ML feature; modelValues means the
// corresponding feature's score.
// BudouX's model consists of thirteen categories, each of which is make up of pairs of the
// feature and its score. As integrating it into jaml.txt, we define thirteen kinds of key and
// value to represent the feature and the corresponding score respectively.
if (U_FAILURE(error)) return;
UnicodeString key;
StackUResourceBundle stackTempBundle;
ResourceDataValue modelKey;
LocalUResourceBundlePointer rbp(ures_openDirect(U_ICUDATA_BRKITR, "jaml", &error));
UResourceBundle *rb = rbp.getAlias();
if (U_FAILURE(error)) return;
int32_t index = 0;
initKeyValue(rb, "UW1Keys", "UW1Values", fModel[index++], error);
initKeyValue(rb, "UW2Keys", "UW2Values", fModel[index++], error);
initKeyValue(rb, "UW3Keys", "UW3Values", fModel[index++], error);
initKeyValue(rb, "UW4Keys", "UW4Values", fModel[index++], error);
initKeyValue(rb, "UW5Keys", "UW5Values", fModel[index++], error);
initKeyValue(rb, "UW6Keys", "UW6Values", fModel[index++], error);
initKeyValue(rb, "BW1Keys", "BW1Values", fModel[index++], error);
initKeyValue(rb, "BW2Keys", "BW2Values", fModel[index++], error);
initKeyValue(rb, "BW3Keys", "BW3Values", fModel[index++], error);
initKeyValue(rb, "TW1Keys", "TW1Values", fModel[index++], error);
initKeyValue(rb, "TW2Keys", "TW2Values", fModel[index++], error);
initKeyValue(rb, "TW3Keys", "TW3Values", fModel[index++], error);
initKeyValue(rb, "TW4Keys", "TW4Values", fModel[index++], error);
fNegativeSum /= 2;
}
void MlBreakEngine::initKeyValue(UResourceBundle *rb, const char *keyName, const char *valueName,
Hashtable &model, UErrorCode &error) {
int32_t keySize = 0;
int32_t valueSize = 0;
int32_t stringLength = 0;
@ -248,15 +243,13 @@ void MlBreakEngine::loadMLModel(UErrorCode &error) {
StackUResourceBundle stackTempBundle;
ResourceDataValue modelKey;
LocalUResourceBundlePointer rbp(ures_openDirect(U_ICUDATA_BRKITR, "jaml", &error));
UResourceBundle* rb = rbp.orphan();
// get modelValues
LocalUResourceBundlePointer modelValue(ures_getByKey(rb, "modelValues", nullptr, &error));
const int32_t* value = ures_getIntVector(modelValue.getAlias(), &valueSize, &error);
LocalUResourceBundlePointer modelValue(ures_getByKey(rb, valueName, nullptr, &error));
const int32_t *value = ures_getIntVector(modelValue.getAlias(), &valueSize, &error);
if (U_FAILURE(error)) return;
// get modelKeys
ures_getValueWithFallback(rb, "modelKeys", stackTempBundle.getAlias(), modelKey, error);
ures_getValueWithFallback(rb, keyName, stackTempBundle.getAlias(), modelKey, error);
ResourceArray stringArray = modelKey.getArray(error);
keySize = stringArray.getSize();
if (U_FAILURE(error)) return;
@ -267,7 +260,7 @@ void MlBreakEngine::loadMLModel(UErrorCode &error) {
if (U_SUCCESS(error)) {
U_ASSERT(idx < valueSize);
fNegativeSum -= value[idx];
fModel.puti(key, value[idx], error);
model.puti(key, value[idx], error);
}
}
}

View file

@ -5,6 +5,7 @@
#define MLBREAKENGINE_H
#include "hash.h"
#include "unicode/resbund.h"
#include "unicode/uniset.h"
#include "unicode/utext.h"
#include "uvectr32.h"
@ -27,7 +28,7 @@ class MlBreakEngine : public UMemory {
* @param status Information on any errors encountered.
*/
MlBreakEngine(const UnicodeSet &digitOrOpenPunctuationOrAlphabetSet,
const UnicodeSet &closePunctuationSet, UErrorCode &status);
const UnicodeSet &closePunctuationSet, UErrorCode &status);
/**
* Virtual destructor.
@ -60,31 +61,50 @@ class MlBreakEngine : public UMemory {
void loadMLModel(UErrorCode &error);
/**
* Initialize the element list from the input string.
* In the machine learning's model file, specify the name of the key and value to load the
* corresponding feature and its score.
*
* @param rb A ResouceBundle corresponding to the model file.
* @param keyName The kay name in the model file.
* @param valueName The value name in the model file.
* @param model A hashtable to store the pairs of the feature and its score.
* @param error Information on any errors encountered.
*/
void initKeyValue(UResourceBundle *rb, const char *keyName, const char *valueName,
Hashtable &model, UErrorCode &error);
/**
* Initialize the index list from the input string.
*
* @param inString A input string to be segmented.
* @param elementList A list to store the first six characters.
* @param indexList A code unit index list of inString.
* @param status Information on any errors encountered.
* @return The number of code units of the first six characters in inString.
* @return The number of code units of the first four characters in inString.
*/
int32_t initElementList(const UnicodeString &inString, UChar32* elementList,
UErrorCode &status) const;
int32_t initIndexList(const UnicodeString &inString, int32_t *indexList,
UErrorCode &status) const;
/**
* Evaluate whether the index is a potential breakpoint.
*
* @param elementList A list including six elements for the breakpoint evaluation.
* @param index The breakpoint index to be evaluated.
* @param inString A input string to be segmented.
* @param indexList A code unit index list of the inString.
* @param startIdx The start index of the indexList.
* @param numCodeUnits The current code unit boundary of the indexList.
* @param numBreaks The accumulated number of breakpoints.
* @param boundary A vector including the index of the breakpoint.
* @param status Information on any errors encountered.
* @return The number of breakpoints
*/
void evaluateBreakpoint(UChar32* elementList, int32_t index, int32_t &numBreaks,
UVector32 &boundary, UErrorCode &status) const;
int32_t evaluateBreakpoint(const UnicodeString &inString, int32_t *indexList, int32_t startIdx,
int32_t numCodeUnits, int32_t numBreaks, UVector32 &boundary,
UErrorCode &status) const;
void printUnicodeString(const UnicodeString &s) const;
UnicodeSet fDigitOrOpenPunctuationOrAlphabetSet;
UnicodeSet fClosePunctuationSet;
Hashtable fModel;
Hashtable fModel[13]; // {UW1, UW2, ... UW6, BW1, ... BW3, TW1, TW2, ... TW4} 6+3+4= 13
int32_t fNegativeSum;
};

File diff suppressed because it is too large Load diff

View file

@ -8,26 +8,36 @@ import static com.ibm.icu.impl.CharacterIteration.current32;
import static com.ibm.icu.impl.CharacterIteration.next32;
import static com.ibm.icu.impl.CharacterIteration.previous32;
import com.ibm.icu.impl.Assert;
import com.ibm.icu.impl.ICUData;
import com.ibm.icu.lang.UCharacter;
import com.ibm.icu.text.UnicodeSet;
import com.ibm.icu.util.UResourceBundle;
import com.ibm.icu.util.UResourceBundleIterator;
import java.lang.System;
import java.text.CharacterIterator;
import java.util.Arrays;
import java.util.ArrayList;
import java.util.List;
import java.util.HashMap;
public class MlBreakEngine {
enum ModelIndex {
kUWStart(0), kBWStart(6), kTWStart(9);
private final int value;
private static final int INVALID = '|';
private static final String INVALID_STRING = "|";
private ModelIndex(int value) {
this.value = value;
}
public int getValue() {
return value;
}
}
public class MlBreakEngine {
// {UW1, UW2, ... UW6, BW1, ... BW3, TW1, TW2, ... TW4} 6+3+4= 13
private static final int MAX_FEATURE = 13;
private UnicodeSet fDigitOrOpenPunctuationOrAlphabetSet;
private UnicodeSet fClosePunctuationSet;
private HashMap<String, Integer> fModel;
private List<HashMap<String, Integer>> fModel;
private int fNegativeSum;
/**
@ -41,7 +51,10 @@ public class MlBreakEngine {
UnicodeSet closePunctuationSet) {
fDigitOrOpenPunctuationOrAlphabetSet = digitOrOpenPunctuationOrAlphabetSet;
fClosePunctuationSet = closePunctuationSet;
fModel = new HashMap<String, Integer>();
fModel = new ArrayList<HashMap<String, Integer>>(MAX_FEATURE);
for (int i = 0; i < MAX_FEATURE; i++) {
fModel.add(new HashMap<String, Integer>());
}
fNegativeSum = 0;
loadMLModel();
}
@ -49,42 +62,47 @@ public class MlBreakEngine {
/**
* Divide up a range of characters handled by this break engine.
*
* @param inText A input text.
* @param startPos The start index of the input text.
* @param endPos The end index of the input text.
* @param inString A input string normalized from inText from startPos to endPos
* @param numCodePts The number of code points of inString
* @param charPositions A map that transforms inString's code point index to code unit index.
* @param foundBreaks A list to store the breakpoint.
* @param inText An input text.
* @param startPos The start index of the input text.
* @param endPos The end index of the input text.
* @param inString A input string normalized from inText from startPos to endPos
* @param codePointLength The number of code points of inString
* @param charPositions A map that transforms inString's code point index to code unit index.
* @param foundBreaks A list to store the breakpoint.
* @return The number of breakpoints
*/
public int divideUpRange(CharacterIterator inText, int startPos, int endPos,
CharacterIterator inString, int numCodePts, int[] charPositions,
CharacterIterator inString, int codePointLength, int[] charPositions,
DictionaryBreakEngine.DequeI foundBreaks) {
if (startPos >= endPos) {
return 0;
}
ArrayList<Integer> boundary = new ArrayList<Integer>(numCodePts);
// The ML model groups six char to evaluate if the 4th char is a breakpoint.
// Like a sliding window, the elementList removes the first char and appends the new char
// from inString in each iteration so that its size always remains at six.
int elementList[] = new int[6];
initElementList(inString, elementList, numCodePts);
ArrayList<Integer> boundary = new ArrayList<Integer>(codePointLength);
String inputStr = transform(inString);
// The ML algorithm groups six char and evaluates whether the 4th char is a breakpoint.
// In each iteration, it evaluates the 4th char and then moves forward one char like
// sliding window. Initially, the first six values in the indexList are
// [-1, -1, 0, 1, 2, 3]. After moving forward, finally the last six values in the indexList
// are [length-4, length-3, length-2, length-1, -1, -1]. The "+4" here means four extra
// "-1".
int indexSize = codePointLength + 4;
int indexList[] = new int[indexSize];
int numCodeUnits = initIndexList(inString, indexList, codePointLength);
// Add a break for the start.
boundary.add(0, 0);
for (int i = 1; i < numCodePts; i++) {
evaluateBreakpoint(elementList, i, boundary);
if (i + 1 > numCodePts) {
break;
for (int idx = 0; idx + 1 < codePointLength; idx++) {
evaluateBreakpoint(inputStr, indexList, idx, numCodeUnits, boundary);
if (idx + 4 < codePointLength) {
indexList[idx + 6] = numCodeUnits;
numCodeUnits += Character.charCount(next32(inString));
}
shiftLeftOne(elementList);
elementList[5] = (i + 3) < numCodePts ? next32(inString) : INVALID;
}
// Add a break for the end if there is not one there already.
if (boundary.get(boundary.size() - 1) != numCodePts) {
boundary.add(numCodePts);
if (boundary.get(boundary.size() - 1) != codePointLength) {
boundary.add(codePointLength);
}
int correctedNumBreaks = 0;
@ -127,137 +145,94 @@ public class MlBreakEngine {
return correctedNumBreaks;
}
private void shiftLeftOne(int[] elementList) {
int length = elementList.length;
for (int i = 1; i < length; i++) {
elementList[i - 1] = elementList[i];
/**
* Transform a CharacterIterator into a String.
*/
private String transform(CharacterIterator inString) {
StringBuilder sb = new StringBuilder();
inString.setIndex(0);
for (char c = inString.first(); c != CharacterIterator.DONE; c = inString.next()) {
sb.append(c);
}
return sb.toString();
}
/**
* Evaluate whether the index is a potential breakpoint.
* Evaluate whether the breakpointIdx is a potential breakpoint.
*
* @param elementList A list including six elements for the breakpoint evaluation.
* @param index The breakpoint index to be evaluated.
* @param boundary An list including the index of the breakpoint.
* @param inputStr An input string to be segmented.
* @param indexList A code unit index list of the inputStr.
* @param startIdx The start index of the indexList.
* @param numCodeUnits The current code unit boundary of the indexList.
* @param boundary A list including the index of the breakpoint.
*/
private void evaluateBreakpoint(int[] elementList, int index, ArrayList<Integer> boundary) {
String[] featureList = new String[MAX_FEATURE];
final int w1 = elementList[0];
final int w2 = elementList[1];
final int w3 = elementList[2];
final int w4 = elementList[3];
final int w5 = elementList[4];
final int w6 = elementList[5];
StringBuilder sb = new StringBuilder();
int idx = 0;
if (w1 != INVALID) {
featureList[idx++] = sb.append("UW1:").appendCodePoint(w1).toString();
}
if (w2 != INVALID) {
sb.setLength(0);
featureList[idx++] = sb.append("UW2:").appendCodePoint(w2).toString();
}
if (w3 != INVALID) {
sb.setLength(0);
featureList[idx++] = sb.append("UW3:").appendCodePoint(w3).toString();
}
if (w4 != INVALID) {
sb.setLength(0);
featureList[idx++] = sb.append("UW4:").appendCodePoint(w4).toString();
}
if (w5 != INVALID) {
sb.setLength(0);
featureList[idx++] = sb.append("UW5:").appendCodePoint(w5).toString();
}
if (w6 != INVALID) {
sb.setLength(0);
featureList[idx++] = sb.append("UW6:").appendCodePoint(w6).toString();
}
if (w2 != INVALID && w3 != INVALID) {
sb.setLength(0);
featureList[idx++] = sb.append("BW1:").appendCodePoint(w2).appendCodePoint(
w3).toString();
}
if (w3 != INVALID && w4 != INVALID) {
sb.setLength(0);
featureList[idx++] = sb.append("BW2:").appendCodePoint(w3).appendCodePoint(
w4).toString();
}
if (w4 != INVALID && w5 != INVALID) {
sb.setLength(0);
featureList[idx++] = sb.append("BW3:").appendCodePoint(w4).appendCodePoint(
w5).toString();
}
if (w1 != INVALID && w2 != INVALID && w3 != INVALID) {
sb.setLength(0);
featureList[idx++] = sb.append("TW1:").appendCodePoint(w1).appendCodePoint(
w2).appendCodePoint(w3).toString();
}
if (w2 != INVALID && w3 != INVALID && w4 != INVALID) {
sb.setLength(0);
featureList[idx++] = sb.append("TW2:").appendCodePoint(w2).appendCodePoint(
w3).appendCodePoint(w4).toString();
}
if (w3 != INVALID && w4 != INVALID && w5 != INVALID) {
sb.setLength(0);
featureList[idx++] = sb.append("TW3:").appendCodePoint(w3).appendCodePoint(
w4).appendCodePoint(w5).toString();
}
if (w4 != INVALID && w5 != INVALID && w6 != INVALID) {
sb.setLength(0);
featureList[idx++] = sb.append("TW4:").appendCodePoint(w4).appendCodePoint(
w5).appendCodePoint(w6).toString();
}
private void evaluateBreakpoint(String inputStr, int[] indexList, int startIdx,
int numCodeUnits, ArrayList<Integer> boundary) {
int start = 0, end = 0;
int score = fNegativeSum;
for (int j = 0; j < idx; j++) {
if (fModel.containsKey(featureList[j])) {
score += (2 * fModel.get(featureList[j]));
for (int i = 0; i < 6; i++) {
// UW1 ~ UW6
start = startIdx + i;
if (indexList[start] != -1) {
end = (indexList[start + 1] != -1) ? indexList[start + 1] : numCodeUnits;
score += fModel.get(ModelIndex.kUWStart.getValue() + i).getOrDefault(
inputStr.substring(indexList[start], end), 0);
}
}
for (int i = 0; i < 3; i++) {
// BW1 ~ BW3
start = startIdx + i + 1;
if (indexList[start] != -1 && indexList[start + 1] != -1) {
end = (indexList[start + 2] != -1) ? indexList[start + 2] : numCodeUnits;
score += fModel.get(ModelIndex.kBWStart.getValue() + i).getOrDefault(
inputStr.substring(indexList[start], end), 0);
}
}
for (int i = 0; i < 4; i++) {
// TW1 ~ TW4
start = startIdx + i;
if (indexList[start] != -1
&& indexList[start + 1] != -1
&& indexList[start + 2] != -1) {
end = (indexList[start + 3] != -1) ? indexList[start + 3] : numCodeUnits;
score += fModel.get(ModelIndex.kTWStart.getValue() + i).getOrDefault(
inputStr.substring(indexList[start], end), 0);
}
}
if (score > 0) {
boundary.add(index);
boundary.add(startIdx + 1);
}
}
/**
* Initialize the element list from the input string.
* Initialize the index list from the input string.
*
* @param inString A input string to be segmented.
* @param elementList A list to store the first six characters.
* @param numCodePts The number of code points of input string
* @param inString An input string to be segmented.
* @param indexList A code unit index list of the inString.
* @param codePointLength The number of code points of the input string
* @return The number of the code units of the first six characters in inString.
*/
private int initElementList(CharacterIterator inString, int[] elementList, int numCodePts) {
private int initIndexList(CharacterIterator inString, int[] indexList, int codePointLength) {
int index = 0;
inString.setIndex(index);
int w1, w2, w3, w4, w5, w6;
w1 = w2 = w3 = w4 = w5 = w6 = INVALID;
if (numCodePts > 0) {
w3 = current32(inString);
index += Character.charCount(w3);
if (numCodePts > 1) {
w4 = next32(inString);
index += Character.charCount(w3);
if (numCodePts > 2) {
w5 = next32(inString);
index += Character.charCount(w5);
if (numCodePts > 3) {
w6 = next32(inString);
index += Character.charCount(w6);
Arrays.fill(indexList, -1);
if (codePointLength > 0) {
indexList[2] = 0;
index += Character.charCount(current32(inString));
if (codePointLength > 1) {
indexList[3] = index;
index += Character.charCount(next32(inString));
if (codePointLength > 2) {
indexList[4] = index;
index += Character.charCount(next32(inString));
if (codePointLength > 3) {
indexList[5] = index;
index += Character.charCount(next32(inString));
}
}
}
}
elementList[0] = w1;
elementList[1] = w2;
elementList[2] = w3;
elementList[3] = w4;
elementList[4] = w5;
elementList[5] = w6;
return index;
}
@ -268,13 +243,41 @@ public class MlBreakEngine {
int index = 0;
UResourceBundle rb = UResourceBundle.getBundleInstance(ICUData.ICU_BRKITR_BASE_NAME,
"jaml");
UResourceBundle keyBundle = rb.get("modelKeys");
UResourceBundle valueBundle = rb.get("modelValues");
initKeyValue(rb, "UW1Keys", "UW1Values", fModel.get(index++));
initKeyValue(rb, "UW2Keys", "UW2Values", fModel.get(index++));
initKeyValue(rb, "UW3Keys", "UW3Values", fModel.get(index++));
initKeyValue(rb, "UW4Keys", "UW4Values", fModel.get(index++));
initKeyValue(rb, "UW5Keys", "UW5Values", fModel.get(index++));
initKeyValue(rb, "UW6Keys", "UW6Values", fModel.get(index++));
initKeyValue(rb, "BW1Keys", "BW1Values", fModel.get(index++));
initKeyValue(rb, "BW2Keys", "BW2Values", fModel.get(index++));
initKeyValue(rb, "BW3Keys", "BW3Values", fModel.get(index++));
initKeyValue(rb, "TW1Keys", "TW1Values", fModel.get(index++));
initKeyValue(rb, "TW2Keys", "TW2Values", fModel.get(index++));
initKeyValue(rb, "TW3Keys", "TW3Values", fModel.get(index++));
initKeyValue(rb, "TW4Keys", "TW4Values", fModel.get(index++));
fNegativeSum /= 2;
}
/**
* In the machine learning's model file, specify the name of the key and value to load the
* corresponding feature and its score.
*
* @param rb A RedouceBundle corresponding to the model file.
* @param keyName The kay name in the model file.
* @param valueName The value name in the model file.
* @param map A HashMap to store the pairs of the feature and its score.
*/
private void initKeyValue(UResourceBundle rb, String keyName, String valueName,
HashMap<String, Integer> map) {
int idx = 0;
UResourceBundle keyBundle = rb.get(keyName);
UResourceBundle valueBundle = rb.get(valueName);
int[] value = valueBundle.getIntVector();
UResourceBundleIterator iterator = keyBundle.getIterator();
while (iterator.hasNext()) {
fNegativeSum -= value[index];
fModel.put(iterator.nextString(), value[index++]);
fNegativeSum -= value[idx];
map.put(iterator.nextString(), value[idx++]);
}
}
}