diff --git a/javanano/src/main/java/com/google/protobuf/nano/InternalNano.java b/javanano/src/main/java/com/google/protobuf/nano/InternalNano.java index a9a459dd..044c30dd 100644 --- a/javanano/src/main/java/com/google/protobuf/nano/InternalNano.java +++ b/javanano/src/main/java/com/google/protobuf/nano/InternalNano.java @@ -491,4 +491,44 @@ public final class InternalNano { } return size; } + + /** + * Checks whether two {@link Map} are equal. We don't use the default equals + * method of {@link Map} because it compares by identity not by content for + * byte arrays. + */ + public static boolean equals(Map a, Map b) { + if (a == b) { + return true; + } + if (a == null) { + return b.size() == 0; + } + if (b == null) { + return a.size() == 0; + } + if (a.size() != b.size()) { + return false; + } + for (Entry entry : a.entrySet()) { + if (!b.containsKey(entry.getKey())) { + return false; + } + if (!equalsMapValue(entry.getValue(), b.get(entry.getKey()))) { + return false; + } + } + return true; + } + + private static boolean equalsMapValue(Object a, Object b) { + if (a == null || b == null) { + throw new IllegalStateException( + "keys and values in maps cannot be null"); + } + if (a instanceof byte[] && b instanceof byte[]) { + return Arrays.equals((byte[]) a, (byte[]) b); + } + return a.equals(b); + } } diff --git a/javanano/src/test/java/com/google/protobuf/nano/NanoTest.java b/javanano/src/test/java/com/google/protobuf/nano/NanoTest.java index 4159e662..aa279da2 100644 --- a/javanano/src/test/java/com/google/protobuf/nano/NanoTest.java +++ b/javanano/src/test/java/com/google/protobuf/nano/NanoTest.java @@ -31,6 +31,7 @@ package com.google.protobuf.nano; import com.google.protobuf.nano.MapTestProto.TestMap; +import com.google.protobuf.nano.MapTestProto.TestMap.MessageValue; import com.google.protobuf.nano.NanoAccessorsOuterClass.TestNanoAccessors; import com.google.protobuf.nano.NanoHasOuterClass.TestAllTypesNanoHas; import com.google.protobuf.nano.NanoOuterClass.TestAllTypesNano; @@ -47,6 +48,7 @@ import junit.framework.TestCase; import java.util.Arrays; import java.util.HashMap; import java.util.Map; +import java.util.TreeMap; /** * Test nano runtime. @@ -3824,15 +3826,107 @@ public class NanoTest extends TestCase { assertEquals(0, messageValue.value2); } + public void testMapEquals() throws Exception { + TestMap a = new TestMap(); + TestMap b = new TestMap(); + + // empty and null map fields are equal. + assertTestMapEqual(a, b); + a.int32ToBytesField = new HashMap(); + assertTestMapEqual(a, b); + + a.int32ToInt32Field = new HashMap(); + b.int32ToInt32Field = new HashMap(); + setMap(a.int32ToInt32Field, deepCopy(int32Values), deepCopy(int32Values)); + setMap(b.int32ToInt32Field, deepCopy(int32Values), deepCopy(int32Values)); + assertTestMapEqual(a, b); + + a.int32ToMessageField = + new HashMap(); + b.int32ToMessageField = + new HashMap(); + setMap(a.int32ToMessageField, + deepCopy(int32Values), deepCopy(messageValues)); + setMap(b.int32ToMessageField, + deepCopy(int32Values), deepCopy(messageValues)); + assertTestMapEqual(a, b); + + a.stringToInt32Field = new HashMap(); + b.stringToInt32Field = new HashMap(); + setMap(a.stringToInt32Field, deepCopy(stringValues), deepCopy(int32Values)); + setMap(b.stringToInt32Field, deepCopy(stringValues), deepCopy(int32Values)); + assertTestMapEqual(a, b); + + a.int32ToBytesField = new HashMap(); + b.int32ToBytesField = new HashMap(); + setMap(a.int32ToBytesField, deepCopy(int32Values), deepCopy(bytesValues)); + setMap(b.int32ToBytesField, deepCopy(int32Values), deepCopy(bytesValues)); + assertTestMapEqual(a, b); + + // Make sure the map implementation does not matter. + a.int32ToStringField = new TreeMap(); + b.int32ToStringField = new HashMap(); + setMap(a.int32ToStringField, deepCopy(int32Values), deepCopy(stringValues)); + setMap(b.int32ToStringField, deepCopy(int32Values), deepCopy(stringValues)); + assertTestMapEqual(a, b); + + a.clear(); + b.clear(); + + // unequal cases: different value + a.int32ToInt32Field = new HashMap(); + b.int32ToInt32Field = new HashMap(); + a.int32ToInt32Field.put(1, 1); + b.int32ToInt32Field.put(1, 2); + assertTestMapUnequal(a, b); + // unequal case: additional entry + b.int32ToInt32Field.put(1, 1); + b.int32ToInt32Field.put(2, 1); + assertTestMapUnequal(a, b); + a.int32ToInt32Field.put(2, 1); + assertTestMapEqual(a, b); + + // unequal case: different message value. + a.int32ToMessageField = + new HashMap(); + b.int32ToMessageField = + new HashMap(); + MessageValue va = new MessageValue(); + va.value = 1; + MessageValue vb = new MessageValue(); + vb.value = 1; + a.int32ToMessageField.put(1, va); + b.int32ToMessageField.put(1, vb); + assertTestMapEqual(a, b); + vb.value = 2; + assertTestMapUnequal(a, b); + } + + private static void assertTestMapEqual(TestMap a, TestMap b) + throws Exception { + assertEquals(a.hashCode(), b.hashCode()); + assertTrue(a.equals(b)); + assertTrue(b.equals(a)); + } + + private static void assertTestMapUnequal(TestMap a, TestMap b) + throws Exception { + assertFalse(a.equals(b)); + assertFalse(b.equals(a)); + } + private static final Integer[] int32Values = new Integer[] { 0, 1, -1, Integer.MAX_VALUE, Integer.MIN_VALUE, }; + private static final Long[] int64Values = new Long[] { 0L, 1L, -1L, Long.MAX_VALUE, Long.MIN_VALUE, }; + private static final String[] stringValues = new String[] { "", "hello", "world", "foo", "bar", }; + private static final byte[][] bytesValues = new byte[][] { new byte[] {}, new byte[] {0}, @@ -3840,13 +3934,16 @@ public class NanoTest extends TestCase { new byte[] {127, -128}, new byte[] {'a', 'b', '0', '1'}, }; + private static final Boolean[] boolValues = new Boolean[] { false, true, }; + private static final Integer[] enumValues = new Integer[] { TestMap.FOO, TestMap.BAR, TestMap.BAZ, TestMap.QUX, Integer.MAX_VALUE /* unknown */, }; + private static final TestMap.MessageValue[] messageValues = new TestMap.MessageValue[] { newMapValueMessage(0), @@ -3855,15 +3952,37 @@ public class NanoTest extends TestCase { newMapValueMessage(Integer.MAX_VALUE), newMapValueMessage(Integer.MIN_VALUE), }; + private static TestMap.MessageValue newMapValueMessage(int value) { TestMap.MessageValue result = new TestMap.MessageValue(); result.value = value; return result; } + @SuppressWarnings("unchecked") + private static T[] deepCopy(T[] orig) throws Exception { + if (orig instanceof MessageValue[]) { + MessageValue[] result = new MessageValue[orig.length]; + for (int i = 0; i < orig.length; i++) { + result[i] = new MessageValue(); + MessageNano.mergeFrom( + result[i], MessageNano.toByteArray((MessageValue) orig[i])); + } + return (T[]) result; + } + if (orig instanceof byte[][]) { + byte[][] result = new byte[orig.length][]; + for (int i = 0; i < orig.length; i++) { + byte[] origBytes = (byte[]) orig[i]; + result[i] = Arrays.copyOf(origBytes, origBytes.length); + } + } + return Arrays.copyOf(orig, orig.length); + } + private void setMap(Map map, K[] keys, V[] values) { assert(keys.length == values.length); - for (int i = 0; i < keys.length; ++i) { + for (int i = 0; i < keys.length; i++) { map.put(keys[i], values[i]); } } @@ -3871,7 +3990,7 @@ public class NanoTest extends TestCase { private void assertMapSet( Map map, K[] keys, V[] values) throws Exception { assert(keys.length == values.length); - for (int i = 0; i < values.length; ++i) { + for (int i = 0; i < values.length; i++) { assertEquals(values[i], map.get(keys[i])); } assertEquals(keys.length, map.size()); diff --git a/src/google/protobuf/compiler/javanano/javanano_map_field.cc b/src/google/protobuf/compiler/javanano/javanano_map_field.cc index 082573dd..c816fb3d 100644 --- a/src/google/protobuf/compiler/javanano/javanano_map_field.cc +++ b/src/google/protobuf/compiler/javanano/javanano_map_field.cc @@ -166,6 +166,11 @@ GenerateSerializedSizeCode(io::Printer* printer) const { void MapFieldGenerator:: GenerateEqualsCode(io::Printer* printer) const { + printer->Print(variables_, + "if (!com.google.protobuf.nano.InternalNano.equals(\n" + " this.$name$, other.$name$)) {\n" + " return false;\n" + "}\n"); } void MapFieldGenerator::