diff --git a/codegen.py b/codegen.py index 7ef06b4141..3a67ce435d 100755 --- a/codegen.py +++ b/codegen.py @@ -368,6 +368,9 @@ def printGetter(fieldType, fieldName): print(" public int getClassId() { return %i; }" % (c.index)) print(" public String getClassName() { return \"%s\"; }" % (c.name)) + if c.fields: + equalsHashCode(spec, c.fields, java_class_name(c.name), 'Properties', False) + printPropertiesBuilder(c) #accessor methods @@ -400,6 +403,49 @@ def printPropertiesClasses(): #-------------------------------------------------------------------------------- +def equalsHashCode(spec, fields, jClassName, classSuffix, usePrimitiveType): + print() + print() + print(" @Override") + print(" public boolean equals(Object o) {") + print(" if (this == o)") + print(" return true;") + print(" if (o == null || getClass() != o.getClass())") + print(" return false;") + print(" %s%s that = (%s%s) o;" % (jClassName, classSuffix, jClassName, classSuffix)) + + for f in fields: + (fType, fName) = (java_field_type(spec, f.domain), java_field_name(f.name)) + if usePrimitiveType and fType in javaScalarTypes: + print(" if (%s != that.%s)" % (fName, fName)) + else: + print(" if (%s != null ? !%s.equals(that.%s) : that.%s != null)" % (fName, fName, fName, fName)) + + print(" return false;") + + print(" return true;") + print(" }") + + print() + print(" @Override") + print(" public int hashCode() {") + print(" int result = 0;") + + for f in fields: + (fType, fName) = (java_field_type(spec, f.domain), java_field_name(f.name)) + if usePrimitiveType and fType in javaScalarTypes: + if fType == 'boolean': + print(" result = 31 * result + (%s ? 1 : 0);" % fName) + elif fType == 'long': + print(" result = 31 * result + (int) (%s ^ (%s >>> 32));" % (fName, fName)) + else: + print(" result = 31 * result + %s;" % fName) + else: + print(" result = 31 * result + (%s != null ? %s.hashCode() : 0);" % (fName, fName)) + + print(" return result;") + print(" }") + def genJavaImpl(spec): def printHeader(): printFileHeader() @@ -503,6 +549,8 @@ def write_arguments(): getters() constructors() others() + if m.arguments: + equalsHashCode(spec, m.arguments, java_class_name(m.name), '', True) argument_debug_string() write_arguments() diff --git a/src/test/java/com/rabbitmq/client/test/ClientTests.java b/src/test/java/com/rabbitmq/client/test/ClientTests.java index 2b4c8ea04a..c299764f31 100644 --- a/src/test/java/com/rabbitmq/client/test/ClientTests.java +++ b/src/test/java/com/rabbitmq/client/test/ClientTests.java @@ -64,7 +64,8 @@ JsonRpcTest.class, AddressTest.class, DefaultRetryHandlerTest.class, - NioDeadlockOnConnectionClosing.class + NioDeadlockOnConnectionClosing.class, + GeneratedClassesTest.class }) public class ClientTests { diff --git a/src/test/java/com/rabbitmq/client/test/GeneratedClassesTest.java b/src/test/java/com/rabbitmq/client/test/GeneratedClassesTest.java new file mode 100644 index 0000000000..e9dbcfca98 --- /dev/null +++ b/src/test/java/com/rabbitmq/client/test/GeneratedClassesTest.java @@ -0,0 +1,135 @@ +// Copyright (c) 2018 Pivotal Software, Inc. All rights reserved. +// +// This software, the RabbitMQ Java client library, is triple-licensed under the +// Mozilla Public License 1.1 ("MPL"), the GNU General Public License version 2 +// ("GPL") and the Apache License version 2 ("ASL"). For the MPL, please see +// LICENSE-MPL-RabbitMQ. For the GPL, please see LICENSE-GPL2. For the ASL, +// please see LICENSE-APACHE2. +// +// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, +// either express or implied. See the LICENSE file for specific language governing +// rights and limitations of this software. +// +// If you have any questions regarding licensing, please contact us at +// info@rabbitmq.com. + +package com.rabbitmq.client.test; + +import com.rabbitmq.client.AMQP; +import com.rabbitmq.client.impl.AMQImpl; +import org.junit.Test; + +import java.util.Calendar; +import java.util.Date; + +import static java.util.Collections.singletonMap; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; + +/** + * + */ +public class GeneratedClassesTest { + + @Test + public void amqpPropertiesEqualsHashCode() { + checkEquals( + new AMQP.BasicProperties.Builder().correlationId("one").build(), + new AMQP.BasicProperties.Builder().correlationId("one").build() + ); + checkNotEquals( + new AMQP.BasicProperties.Builder().correlationId("one").build(), + new AMQP.BasicProperties.Builder().correlationId("two").build() + ); + Date date = Calendar.getInstance().getTime(); + checkEquals( + new AMQP.BasicProperties.Builder() + .deliveryMode(1) + .headers(singletonMap("one", "two")) + .correlationId("123") + .expiration("later") + .priority(10) + .replyTo("me") + .contentType("text/plain") + .contentEncoding("UTF-8") + .userId("jdoe") + .appId("app1") + .clusterId("cluster1") + .messageId("message123") + .timestamp(date) + .type("type") + .build(), + new AMQP.BasicProperties.Builder() + .deliveryMode(1) + .headers(singletonMap("one", "two")) + .correlationId("123") + .expiration("later") + .priority(10) + .replyTo("me") + .contentType("text/plain") + .contentEncoding("UTF-8") + .userId("jdoe") + .appId("app1") + .clusterId("cluster1") + .messageId("message123") + .timestamp(date) + .type("type") + .build() + ); + checkNotEquals( + new AMQP.BasicProperties.Builder() + .deliveryMode(1) + .headers(singletonMap("one", "two")) + .correlationId("123") + .expiration("later") + .priority(10) + .replyTo("me") + .contentType("text/plain") + .contentEncoding("UTF-8") + .userId("jdoe") + .appId("app1") + .clusterId("cluster1") + .messageId("message123") + .timestamp(date) + .type("type") + .build(), + new AMQP.BasicProperties.Builder() + .deliveryMode(2) + .headers(singletonMap("one", "two")) + .correlationId("123") + .expiration("later") + .priority(10) + .replyTo("me") + .contentType("text/plain") + .contentEncoding("UTF-8") + .userId("jdoe") + .appId("app1") + .clusterId("cluster1") + .messageId("message123") + .timestamp(date) + .type("type") + .build() + ); + + } + + @Test public void amqImplEqualsHashCode() { + checkEquals( + new AMQImpl.Basic.Deliver("tag", 1L, false, "amq.direct","rk"), + new AMQImpl.Basic.Deliver("tag", 1L, false, "amq.direct","rk") + ); + checkNotEquals( + new AMQImpl.Basic.Deliver("tag", 1L, false, "amq.direct","rk"), + new AMQImpl.Basic.Deliver("tag", 2L, false, "amq.direct","rk") + ); + } + + private void checkEquals(Object o1, Object o2) { + assertEquals(o1, o2); + assertEquals(o1.hashCode(), o2.hashCode()); + } + + private void checkNotEquals(Object o1, Object o2) { + assertNotEquals(o1, o2); + } +}