Ruby: fixed string freezing for JRuby.
This commit is contained in:
parent
ff7f68ae9f
commit
d07a9963df
5 changed files with 30 additions and 25 deletions
|
@ -148,8 +148,8 @@ public class RubyMap extends RubyObject {
|
|||
*/
|
||||
@JRubyMethod(name = "[]=")
|
||||
public IRubyObject indexSet(ThreadContext context, IRubyObject key, IRubyObject value) {
|
||||
Utils.checkType(context, keyType, key, (RubyModule) valueTypeClass);
|
||||
Utils.checkType(context, valueType, value, (RubyModule) valueTypeClass);
|
||||
key = Utils.checkType(context, keyType, key, (RubyModule) valueTypeClass);
|
||||
value = Utils.checkType(context, valueType, value, (RubyModule) valueTypeClass);
|
||||
IRubyObject symbol;
|
||||
if (valueType == Descriptors.FieldDescriptor.Type.ENUM &&
|
||||
Utils.isRubyNum(value) &&
|
||||
|
|
|
@ -504,7 +504,7 @@ public class RubyMessage extends RubyObject {
|
|||
break;
|
||||
case BYTES:
|
||||
case STRING:
|
||||
Utils.validateStringEncoding(context.runtime, fieldDescriptor.getType(), value);
|
||||
Utils.validateStringEncoding(context, fieldDescriptor.getType(), value);
|
||||
RubyString str = (RubyString) value;
|
||||
switch (fieldDescriptor.getType()) {
|
||||
case BYTES:
|
||||
|
@ -695,7 +695,7 @@ public class RubyMessage extends RubyObject {
|
|||
}
|
||||
}
|
||||
if (addValue) {
|
||||
Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
|
||||
value = Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
|
||||
this.fields.put(fieldDescriptor, value);
|
||||
} else {
|
||||
this.fields.remove(fieldDescriptor);
|
||||
|
|
|
@ -110,7 +110,7 @@ public class RubyRepeatedField extends RubyObject {
|
|||
@JRubyMethod(name = "[]=")
|
||||
public IRubyObject indexSet(ThreadContext context, IRubyObject index, IRubyObject value) {
|
||||
int arrIndex = normalizeArrayIndex(index);
|
||||
Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
|
||||
value = Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
|
||||
IRubyObject defaultValue = defaultValue(context);
|
||||
for (int i = this.storage.size(); i < arrIndex; i++) {
|
||||
this.storage.set(i, defaultValue);
|
||||
|
@ -166,7 +166,7 @@ public class RubyRepeatedField extends RubyObject {
|
|||
public IRubyObject push(ThreadContext context, IRubyObject value) {
|
||||
if (!(fieldType == Descriptors.FieldDescriptor.Type.MESSAGE &&
|
||||
value == context.runtime.getNil())) {
|
||||
Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
|
||||
value = Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
|
||||
}
|
||||
this.storage.add(value);
|
||||
return this.storage;
|
||||
|
|
|
@ -64,8 +64,8 @@ public class Utils {
|
|||
return context.runtime.newSymbol(typeName.replace("TYPE_", "").toLowerCase());
|
||||
}
|
||||
|
||||
public static void checkType(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType,
|
||||
IRubyObject value, RubyModule typeClass) {
|
||||
public static IRubyObject checkType(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType,
|
||||
IRubyObject value, RubyModule typeClass) {
|
||||
Ruby runtime = context.runtime;
|
||||
Object val;
|
||||
switch(fieldType) {
|
||||
|
@ -106,7 +106,7 @@ public class Utils {
|
|||
break;
|
||||
case BYTES:
|
||||
case STRING:
|
||||
validateStringEncoding(context.runtime, fieldType, value);
|
||||
value = validateStringEncoding(context, fieldType, value);
|
||||
break;
|
||||
case MESSAGE:
|
||||
if (value.getMetaClass() != typeClass) {
|
||||
|
@ -127,6 +127,7 @@ public class Utils {
|
|||
default:
|
||||
break;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
public static IRubyObject wrapPrimaryValue(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType, Object value) {
|
||||
|
@ -148,10 +149,16 @@ public class Utils {
|
|||
return runtime.newFloat((Double) value);
|
||||
case BOOL:
|
||||
return (Boolean) value ? runtime.getTrue() : runtime.getFalse();
|
||||
case BYTES:
|
||||
return runtime.newString(((ByteString) value).toStringUtf8());
|
||||
case STRING:
|
||||
return runtime.newString(value.toString());
|
||||
case BYTES: {
|
||||
IRubyObject wrapped = runtime.newString(((ByteString) value).toStringUtf8());
|
||||
wrapped.setFrozen(true);
|
||||
return wrapped;
|
||||
}
|
||||
case STRING: {
|
||||
IRubyObject wrapped = runtime.newString(value.toString());
|
||||
wrapped.setFrozen(true);
|
||||
return wrapped;
|
||||
}
|
||||
default:
|
||||
return runtime.getNil();
|
||||
}
|
||||
|
@ -180,25 +187,21 @@ public class Utils {
|
|||
}
|
||||
}
|
||||
|
||||
public static void validateStringEncoding(Ruby runtime, Descriptors.FieldDescriptor.Type type, IRubyObject value) {
|
||||
public static IRubyObject validateStringEncoding(ThreadContext context, Descriptors.FieldDescriptor.Type type, IRubyObject value) {
|
||||
if (!(value instanceof RubyString))
|
||||
throw runtime.newTypeError("Invalid argument for string field.");
|
||||
Encoding encoding = ((RubyString) value).getEncoding();
|
||||
throw context.runtime.newTypeError("Invalid argument for string field.");
|
||||
switch(type) {
|
||||
case BYTES:
|
||||
if (encoding != ASCIIEncoding.INSTANCE)
|
||||
throw runtime.newTypeError("Encoding for bytes fields" +
|
||||
" must be \"ASCII-8BIT\", but was " + encoding);
|
||||
value = ((RubyString)value).encode(context, context.runtime.evalScriptlet("Encoding::ASCII_8BIT"));
|
||||
break;
|
||||
case STRING:
|
||||
if (encoding != UTF8Encoding.INSTANCE
|
||||
&& encoding != USASCIIEncoding.INSTANCE)
|
||||
throw runtime.newTypeError("Encoding for string fields" +
|
||||
" must be \"UTF-8\" or \"ASCII\", but was " + encoding);
|
||||
value = ((RubyString)value).encode(context, context.runtime.evalScriptlet("Encoding::UTF_8"));
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
value.setFrozen(true);
|
||||
return value;
|
||||
}
|
||||
|
||||
public static void checkNameAvailability(ThreadContext context, String name) {
|
||||
|
|
|
@ -861,8 +861,10 @@ module BasicTest
|
|||
m2 = TestMessage.decode_json(json)
|
||||
assert_equal 'foo', m2.optional_string
|
||||
assert_equal ['bar1', 'bar2'], m2.repeated_string
|
||||
assert m2.optional_string.frozen?
|
||||
assert m2.repeated_string[0].frozen?
|
||||
if RUBY_PLATFORM != "java"
|
||||
assert m2.optional_string.frozen?
|
||||
assert m2.repeated_string[0].frozen?
|
||||
end
|
||||
|
||||
proto = m.to_proto
|
||||
m2 = TestMessage.decode(proto)
|
||||
|
|
Loading…
Add table
Reference in a new issue