diff --git a/.github/workflows/ccpp.yml b/.github/workflows/ccpp.yml index 85fb0af..ea900f7 100644 --- a/.github/workflows/ccpp.yml +++ b/.github/workflows/ccpp.yml @@ -10,8 +10,7 @@ jobs: - name: cmake configure run: cmake . - name: cmake build - run: cmake --build . - + run: cmake --build . -j4 linux-build: runs-on: ubuntu-latest steps: @@ -19,6 +18,8 @@ jobs: - name: cmake configure run: cmake . - name: cmake build - run: cmake --build . -# - name: Run Tests -# run: cd tests && ctest --output-on-failure + run: cmake --build . -j4 + - name: network + run: sudo apt install net-tools && ifconfig + - name: Run Tests + run: cd tests && ctest -VV diff --git a/.gitignore b/.gitignore index 823ba94..9b631a8 100644 --- a/.gitignore +++ b/.gitignore @@ -272,3 +272,4 @@ tests/CMakeFiles/* Makefile cmake_install.cmake build/ +html/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 5854c29..9bfa47a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,11 +25,13 @@ add_library (pubsub "src/Serialization.c" "src/System.c" "src/UDPTransport.c" + "src/Parameter.c" + "src/TCPTransport.c" ) if (UNIX) target_link_libraries(pubsub pubsub_msgs) else() -target_link_libraries(pubsub winmm) +target_link_libraries(pubsub winmm pubsub_msgs) endif() set_target_properties(pubsub PROPERTIES DEBUG_POSTFIX "d") target_include_directories(pubsub PUBLIC include/) @@ -46,6 +48,7 @@ add_library (pubsub_shared SHARED "src/Serialization.c" "src/System.c" "src/UDPTransport.c" + "src/TCPTransport.c" "src/Bindings.c") if (UNIX) target_link_libraries(pubsub_shared PUBLIC pubsub_msgs) diff --git a/CMakeSettings.json b/CMakeSettings.json deleted file mode 100644 index 8d94136..0000000 --- a/CMakeSettings.json +++ /dev/null @@ -1,49 +0,0 @@ -{ - // See https://go.microsoft.com//fwlink//?linkid=834763 for more information about this file. - "configurations": [ - { - "name": "x86-Debug", - "generator": "Ninja", - "configurationType": "Debug", - "inheritEnvironments": [ "msvc_x86" ], - "buildRoot": "${projectDir}\\build\\${name}", - "installRoot": "${env.USERPROFILE}\\CMakeBuilds\\${workspaceHash}\\install\\${name}", - "cmakeCommandArgs": "", - "buildCommandArgs": "-v", - "ctestCommandArgs": "" - }, - { - "name": "x86-Release", - "generator": "Ninja", - "configurationType": "RelWithDebInfo", - "inheritEnvironments": [ "msvc_x86" ], - "buildRoot": "${projectDir}\\build\\${name}", - "installRoot": "${env.USERPROFILE}\\CMakeBuilds\\${workspaceHash}\\install\\${name}", - "cmakeCommandArgs": "", - "buildCommandArgs": "-v", - "ctestCommandArgs": "" - }, - { - "name": "x64-Debug", - "generator": "Ninja", - "configurationType": "Debug", - "inheritEnvironments": [ "msvc_x64_x64" ], - "buildRoot": "${projectDir}\\build\\${name}", - "installRoot": "${env.USERPROFILE}\\CMakeBuilds\\${workspaceHash}\\install\\${name}", - "cmakeCommandArgs": "", - "buildCommandArgs": "-v", - "ctestCommandArgs": "" - }, - { - "name": "x64-Release", - "generator": "Ninja", - "configurationType": "RelWithDebInfo", - "inheritEnvironments": [ "msvc_x64_x64" ], - "buildRoot": "${projectDir}\\build\\${name}", - "installRoot": "${env.USERPROFILE}\\CMakeBuilds\\${workspaceHash}\\install\\${name}", - "cmakeCommandArgs": "", - "buildCommandArgs": "-v", - "ctestCommandArgs": "" - } - ] -} \ No newline at end of file diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 8c39ba4..4e8a61e 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -34,8 +34,7 @@ add_executable(read_param "read_param.c") target_link_libraries(read_param pubsub pubsub_msgs) add_dependencies(read_param pubsub pubsub_msgs) -#install(TARGETS pubsubtest -# ARCHIVE DESTINATION "s" -# LIBRARY DESTINATION "s" -# RUNTIME DESTINATION "s" -#) +add_executable(params_cpp "params.cpp") +target_link_libraries(params_cpp pubsub_cpp pubsub_msgs) +add_dependencies(params_cpp pubsub_cpp pubsub_msgs) +set_property(TARGET simple_sub_cpp PROPERTY CXX_STANDARD 11) diff --git a/examples/params.cpp b/examples/params.cpp new file mode 100644 index 0000000..4e96b55 --- /dev/null +++ b/examples/params.cpp @@ -0,0 +1,38 @@ +#include + +#include +#include + +#include + +int main() +{ + // Create the node + pubsub::Node node("simple_parameters"/*node name*/); + + // Adds TCP transport (optional) + struct ps_transport_t tcp_transport; + ps_tcp_transport_init(&tcp_transport, node.getNode()); + ps_node_add_transport(node.getNode(), &tcp_transport); + + // Create the "spinner" which executes callbacks and timers in a background thread + pubsub::BlockingSpinnerWithTimers spinner; + spinner.setNode(node);// Add the node to the spinner + + auto start = pubsub::Time::now();// Gets the current time + + auto parameter = node.parameter("test_1", 1.0, "Description."); + auto parameter2 = node.parameter("test_2", 3.0, "Description 2."); + + // Create a timer which will run at a prescribed interval + spinner.addTimer(1.0/*timer is run every this many seconds*/, [&]() + { + printf("Value: %f %f\n", (float)(double)parameter, (float)(double)parameter2); + }); + + // Wait for the spinner to exit (on control-c) + spinner.run(); + + return 0; +} + diff --git a/examples/simple_logger.cpp b/examples/simple_logger.cpp index 5d4ad89..6ac4b61 100644 --- a/examples/simple_logger.cpp +++ b/examples/simple_logger.cpp @@ -32,8 +32,7 @@ int main() // okay, since we are publishing with shared pointer we actually need to allocate the string properly auto shared = pubsub::msg::StringSharedPtr(new pubsub::msg::String); - shared->value = new char[strlen(msg.value) + 1]; - strcpy(shared->value, msg.value); + shared->value = msg.value; string_pub.publish(shared); msg.value = 0;// so it doesnt get freed by the destructor since we allocated it ourself diff --git a/examples/simple_pub.c b/examples/simple_pub.c index 79f4f76..cd49390 100644 --- a/examples/simple_pub.c +++ b/examples/simple_pub.c @@ -26,10 +26,10 @@ int main() // Create the publisher struct ps_pub_t string_pub; - ps_node_create_publisher(&node, "/data"/*topic name*/, + ps_node_create_publisher_ex(&node, "/data"/*topic name*/, &pubsub__String_def/*message definition*/, &string_pub, - true/*true to "latch" the topic*/); + true/*true to "latch" the topic*/, 1, NULL); // User is responsible for lifetime of the message they publish // Publish does a copy internally if necessary diff --git a/examples/simple_pub.cpp b/examples/simple_pub.cpp index 0f0c750..6bdc61a 100644 --- a/examples/simple_pub.cpp +++ b/examples/simple_pub.cpp @@ -35,7 +35,7 @@ int main() { auto now = pubsub::Time::now(); - // Build and publish the message + // Build and publish the message pubsub::msg::String msg; char value[20]; sprintf(value, "Hello %f", (now-start).toSec()); @@ -44,8 +44,7 @@ int main() }); // Wait for the spinner to exit (on control-c) - spinner.wait(); + spinner.run(); return 0; } - diff --git a/examples/simple_sub.c b/examples/simple_sub.c index 8f7fae3..03dd488 100644 --- a/examples/simple_sub.c +++ b/examples/simple_sub.c @@ -9,6 +9,15 @@ #include +struct ps_sub_t string_sub; +void callback(void* message, unsigned int size, void* cbdata, const struct ps_msg_info_t* info) +{ + // user is responsible for freeing the message and its arrays + struct pubsub__String* data = (struct pubsub__String*)message; + printf("Got message: %s\n", data->value); + pubsub__String_free(data, string_sub.allocator); +} + int main() { // Create the node @@ -23,10 +32,10 @@ int main() ps_node_add_transport(&node, &tcp_transport); // Create the subscriber - struct ps_sub_t string_sub; struct ps_subscriber_options options; ps_subscriber_options_init(&options); options.preferred_transport = PUBSUB_TCP_TRANSPORT;// sets preferred transport to TCP + options.cb = callback; ps_node_create_subscriber_adv(&node, "/data", &pubsub__String_def, &string_sub, &options); // Loop and spin @@ -36,18 +45,8 @@ int main() // Used to prevent this from using 100% CPU, but you can do that through other means ps_node_wait(&node, 1000/*maximum wait time in ms*/); - // Updates the node, which will queue up any received messages + // Updates the node, which will receive messages and call any callbacks as they come in ps_node_spin(&node); - - // our sub has a message definition, so the queue contains real messages - struct pubsub__String* data; - while (data = (struct pubsub__String*)ps_sub_deque(&string_sub)) - { - // user is responsible for freeing the message and its arrays - printf("Got message: %s\n", data->value); - free(data->value); - free(data); - } } // Shutdown the node to free resources diff --git a/examples/simple_sub.cpp b/examples/simple_sub.cpp index 4c0ff35..7e3fd26 100644 --- a/examples/simple_sub.cpp +++ b/examples/simple_sub.cpp @@ -15,7 +15,7 @@ int main() // Create the subscriber, the provided callback will be called each time a message comes in pubsub::Subscriber subscriber(node, "data"/*topic name*/, [](const pubsub::msg::StringSharedPtr& msg) { - printf("Got message %s\n", msg->value); + printf("Got message %s\n", msg->value.c_str()); }, 10/*maximum queue size, after this many messages build up the oldest will get dropped*/); // Create the "spinner" which executes callbacks and timers in a background thread diff --git a/high_level_api/Node.cpp b/high_level_api/Node.cpp index 7c83d0f..c3aef48 100644 --- a/high_level_api/Node.cpp +++ b/high_level_api/Node.cpp @@ -1 +1,10 @@ #include + +namespace pubsub +{ + std::mutex _publisher_mutex; + std::multimap _publishers; + std::multimap _subscribers; + + std::map _remappings; +} diff --git a/include/pubsub/Node.h b/include/pubsub/Node.h index 06b0bc9..8608a3b 100644 --- a/include/pubsub/Node.h +++ b/include/pubsub/Node.h @@ -29,7 +29,8 @@ struct ps_endpoint_t; struct ps_client_t; struct ps_subscribe_req_t; struct ps_allocator_t; -typedef void(*ps_transport_fn_pub_t)(struct ps_transport_t* transport, struct ps_pub_t* publisher, struct ps_client_t* client, const void* message, uint32_t length); +struct ps_msg_ref_t; +typedef void(*ps_transport_fn_pub_t)(struct ps_transport_t* transport, struct ps_pub_t* publisher, struct ps_client_t* client, struct ps_msg_ref_t* message); typedef int(*ps_transport_fn_spin_t)(struct ps_transport_t* transport, struct ps_node_t* node); typedef void(*ps_transport_fn_add_publisher_t)(struct ps_transport_t* transport, struct ps_pub_t* publisher); typedef void(*ps_transport_fn_remove_publisher_t)(struct ps_transport_t* transport, struct ps_pub_t* publisher); @@ -66,6 +67,7 @@ typedef void(*ps_param_confirm_cb_t)(const char* name, double value, void* data) struct ps_node_t { const char* name; + const char* description; unsigned int num_pubs; struct ps_pub_t** pubs; unsigned int num_subs; @@ -129,10 +131,9 @@ struct ps_msg_info_t struct ps_msg_header { uint8_t pid;//packet type id + uint32_t length;//message length uint32_t id;//stream id uint16_t seq;//sequence number - uint8_t index; - uint8_t count; }; #pragma pack(pop) @@ -161,20 +162,22 @@ struct ps_advertise_req_t uint32_t type_hash;// to see if the type is correct uint32_t group_id;// unique (hopefully) id that indicates which process this node is a part of }; -#pragma pack(pop) -#pragma pack(push) -#pragma pack(1) struct ps_subscribe_req_t { uint8_t id; int32_t addr; uint16_t port; }; -#pragma pack(pop) -#pragma pack(push) -#pragma pack(1) +struct ps_unsubscribe_req_t +{ + uint8_t id; + uint32_t addr; + uint16_t port; + uint32_t stream_id; +}; + struct ps_subscribe_accept_t { uint8_t pid;// packet type identifier @@ -193,24 +196,20 @@ void ps_node_init_ex(struct ps_node_t* node, const char* name, const char* ip, b void ps_node_create_publisher(struct ps_node_t* node, const char* topic, const struct ps_message_definition_t* type, struct ps_pub_t* pub, bool latched); -void ps_node_create_subscriber(struct ps_node_t* node, const char* topic, const struct ps_message_definition_t* type, - struct ps_sub_t* sub, - unsigned int queue_size,//make >= 1 - struct ps_allocator_t* allocator,//give null to use default - bool ignore_local);// if ignore local is set, this node ignores publications from itself - // this facilitiates passing messages through shared memory +void ps_node_create_publisher_ex(struct ps_node_t* node, const char* topic, const struct ps_message_definition_t* type, struct ps_pub_t* pub, bool latched, unsigned int recommended_transport, struct ps_allocator_t* allocator); typedef void(*ps_subscriber_fn_cb_t)(void* message, unsigned int size, void* data, const struct ps_msg_info_t* info); struct ps_subscriber_options { - unsigned int queue_size; bool ignore_local; struct ps_allocator_t* allocator; unsigned int skip;// skips to every nth message for throttling ps_subscriber_fn_cb_t cb; + ps_subscriber_fn_cb_t cb_raw; void* cb_data; - uint32_t preferred_transport;// falls back to udp otherwise + int32_t preferred_transport;// falls back to udp otherwise + const char* description; }; void ps_subscriber_options_init(struct ps_subscriber_options* options); @@ -219,7 +218,7 @@ void ps_node_create_subscriber_adv(struct ps_node_t* node, const char* topic, co struct ps_sub_t* sub, const struct ps_subscriber_options* options); -void ps_node_create_subscriber_cb(struct ps_node_t* node, const char* topic, const struct ps_message_definition_t* type, +void ps_node_create_subscriber(struct ps_node_t* node, const char* topic, const struct ps_message_definition_t* type, struct ps_sub_t* sub, ps_subscriber_fn_cb_t cb, void* cb_data, diff --git a/include/pubsub/Parameter.h b/include/pubsub/Parameter.h new file mode 100644 index 0000000..ee1a561 --- /dev/null +++ b/include/pubsub/Parameter.h @@ -0,0 +1,39 @@ +#pragma once + +#include + +#include +#include +#include +#include + +#include + +#include + +#ifdef __cplusplus +extern "C" +{ +#endif + +typedef void(*ps_param_fancy_cb_t)(const char* name, double value, void* data); + +struct ps_parameters +{ + struct pubsub__Parameters msg; + struct ps_pub_t param_pub; + ps_param_fancy_cb_t callback; + void* cb_data; +}; + +void ps_create_parameters(struct ps_node_t* node, struct ps_parameters* params_out, ps_param_fancy_cb_t callback, void* data); + +void ps_destroy_parameters(struct ps_parameters* params); + +void ps_add_parameter_double(struct ps_parameters* params, + const char* name, const char* description, + double value, double min, double max); + +#ifdef __cplusplus +} +#endif diff --git a/include/pubsub/Publisher.h b/include/pubsub/Publisher.h index 638fab5..6593c42 100644 --- a/include/pubsub/Publisher.h +++ b/include/pubsub/Publisher.h @@ -14,20 +14,20 @@ struct ps_message_definition_t; struct ps_endpoint_t { - unsigned short port; - int address; + unsigned short port; + int address; //bool multicast;// this is probably unnecessary }; // publisher client to network to struct ps_client_t { - struct ps_endpoint_t endpoint; - unsigned short sequence_number;// sequence of the networked packets, incremented with each one - unsigned long long last_keepalive;// timestamp of the last keepalive message, used to know when to deactiveate this connection - unsigned int stream_id;// user-unique identifier of what topic this came from - unsigned int modulo; - struct ps_transport_t* transport; + struct ps_endpoint_t endpoint; + unsigned short sequence_number;// sequence of the networked packets, incremented with each one + unsigned long long last_keepalive;// timestamp of the last keepalive message, used to know when to deactiveate this connection + unsigned int stream_id;// user-unique identifier of what topic this came from + unsigned int modulo; + struct ps_transport_t* transport; }; struct ps_pub_t @@ -39,10 +39,13 @@ struct ps_pub_t struct ps_node_t* node; unsigned int num_clients; struct ps_client_t* clients; + + struct ps_allocator_t* allocator; bool latched;// todo make this an enum of options if we add more + uint8_t recommended_transport; - struct ps_msg_t last_message;//only used if latched + struct ps_msg_ref_t* last_message;//only used if latched unsigned int sequence_number; }; diff --git a/include/pubsub/Serialization.h b/include/pubsub/Serialization.h index c27850c..0ceb7cb 100644 --- a/include/pubsub/Serialization.h +++ b/include/pubsub/Serialization.h @@ -8,14 +8,14 @@ extern "C" { #endif - struct ps_allocator_t - { - void*(*alloc)(unsigned int size, void* context); - void(*free)(void*); - void* context; - }; + struct ps_allocator_t + { + void*(*alloc)(unsigned int size, void* context); + void(*free)(void*, void* context); + void* context; + }; - extern struct ps_allocator_t ps_default_allocator; + extern struct ps_allocator_t ps_default_allocator; enum ps_field_types { @@ -31,8 +31,10 @@ extern "C" FT_Float32 = 9, FT_Float64 = 10, FT_MaxFloat,// all floats and ints are less than this, not present in messages - FT_String, + FT_String,// null terminated dynamic length string FT_Struct,//indicates the number of fields following contained in it + FT_StructDefinition, + FT_ArrayString// null terminated fixed length string }; typedef enum ps_field_types ps_field_types; @@ -49,8 +51,13 @@ extern "C" ps_field_types type; ps_field_flags flags;// packed in upper bits of type, but broken out here const char* name; - unsigned int length;//length of the array, 0 if dynamic - unsigned short content_length;//number of fields inside this struct + uint32_t length;//length of the array, 0 if dynamic, or if this is a struct definition, how many following fields are part of it + //context dependent field + union + { + uint16_t string_length;// this field is a ArrayString, the length of said string + uint16_t struct_index;// if this field is a struct, this is the index of the struct definition in the list of fields + }; }; struct ps_msg_enum_t @@ -60,17 +67,27 @@ extern "C" int field;// the field this is associated with in the message }; - // encoded message struct ps_msg_t { void* data; unsigned int len; }; + + struct ps_msg_ref_t + { + void* data; + unsigned int len; + unsigned int refcount; + }; + + void ps_msg_ref_add(struct ps_msg_ref_t* msg); + void ps_msg_ref_free(struct ps_msg_ref_t* msg, struct ps_allocator_t* allocator); struct ps_allocator_t; - typedef struct ps_msg_t(*ps_fn_encode_t)(struct ps_allocator_t* allocator, const void* msg); + typedef struct ps_msg_t(*ps_fn_encode_t)(const void* msg, struct ps_allocator_t* allocator); typedef void*(*ps_fn_decode_t)(const void* data, struct ps_allocator_t* allocator);// allocates the message + typedef void (*ps_fn_free_t)(void* msg, struct ps_allocator_t* allocator);// frees the message struct ps_message_definition_t { unsigned int hash; @@ -79,21 +96,22 @@ extern "C" struct ps_msg_field_t* fields; ps_fn_encode_t encode; ps_fn_decode_t decode; + ps_fn_free_t free; unsigned int num_enums; struct ps_msg_enum_t* enums; }; // Serializes a given message definition to a buffer. // Returns: Number of bytes written - int ps_serialize_message_definition(void* start, const struct ps_message_definition_t* definition); + int ps_serialize_message_definition(void* dst, const struct ps_message_definition_t* definition); // Deserializes a message definition from the specified buffer. - void ps_deserialize_message_definition(const void* start, struct ps_message_definition_t* definition); + void ps_deserialize_message_definition(const void* src, struct ps_message_definition_t* definition); // print out the deserialized contents of the message to console, for rostopic echo like implementations // in yaml format - // if field is non-null only print out the content of that field - void ps_deserialize_print(const void* data, const struct ps_message_definition_t* definition, unsigned int max_array_size, const char* field); + // if field_name is non-null only print out the content of that field + void ps_deserialize_print(const void* data, const struct ps_message_definition_t* definition, unsigned int max_array_size, const char* field_name); struct ps_deserialize_iterator { @@ -108,11 +126,11 @@ extern "C" // Create an iteratator to iterate through the fields of a serialized message // Returns: The iterator - struct ps_deserialize_iterator ps_deserialize_start(const char* msg, const struct ps_message_definition_t* definition); + struct ps_deserialize_iterator ps_deserialize_start(const void* msg, const struct ps_message_definition_t* definition); // Iterate through a serialized message one field at a time // Returns: Start pointer in the message for the current field or zero when at the end - const char* ps_deserialize_iterate(struct ps_deserialize_iterator* iter, const struct ps_msg_field_t** f, uint32_t* l); + const void* ps_deserialize_iterate(struct ps_deserialize_iterator* iter, const struct ps_msg_field_t** f, uint32_t* l); // Frees a dynamically allocated message definition void ps_free_message_definition(struct ps_message_definition_t* definition); @@ -132,7 +150,7 @@ extern "C" // Makes a copy of a given serialized message // Returns: The new copy - struct ps_msg_t ps_msg_cpy(const struct ps_msg_t* msg); + struct ps_msg_t ps_msg_cpy(const struct ps_msg_t* msg, struct ps_allocator_t* allocator); #ifdef __cplusplus } diff --git a/include/pubsub/Subscriber.h b/include/pubsub/Subscriber.h index 9b548b0..549bc58 100644 --- a/include/pubsub/Subscriber.h +++ b/include/pubsub/Subscriber.h @@ -31,19 +31,13 @@ struct ps_sub_t struct ps_allocator_t* allocator; - // used instead of a queue optionally ps_subscriber_fn_cb_t cb; + ps_subscriber_fn_cb_t cb_raw; void* cb_data; - unsigned int preferred_transport;// udp or tcp + int preferred_transport;// udp or tcp, or -1 for no preference unsigned int skip; - - // queue is implemented as a deque - int queue_start;// start index of items in the queue (loops around on positive side) - int queue_size;// maximum size of the queue - int queue_len;// current queue size - void** queue;// pointers to each of the queue items }; #pragma pack(push) @@ -58,10 +52,7 @@ struct ps_sub_req_header_t }; #pragma pack(pop) -void ps_sub_enqueue(struct ps_sub_t* sub, void* message, int data_size, const struct ps_msg_info_t* message_info); - -// if the subscriber was initialized with a type this returns decoded messages -void* ps_sub_deque(struct ps_sub_t* sub); +void ps_sub_receive(struct ps_sub_t* sub, void* encoded_message, int data_size, bool is_reference, const struct ps_msg_info_t* message_info); void ps_sub_destroy(struct ps_sub_t* sub); diff --git a/include/pubsub/TCPTransport.h b/include/pubsub/TCPTransport.h index 1338b22..596f6ac 100644 --- a/include/pubsub/TCPTransport.h +++ b/include/pubsub/TCPTransport.h @@ -1,6 +1,13 @@ #ifndef _PUBSUB_TCP_TRANSPORT_HEADER #define _PUBSUB_TCP_TRANSPORT_HEADER +#pragma once + +#ifdef __cplusplus +extern "C" +{ +#endif + // Must be included first on windows #include @@ -8,6 +15,7 @@ #include #include #include +#include //#include #include @@ -20,6 +28,12 @@ #define PUBSUB_TCP_TRANSPORT 1 +enum +{ + PS_TCP_PROTOCOL_DATA = PS_UDP_PROTOCOL_DATA, + PS_TCP_PROTOCOL_MESSAGE_DEFINITION = 0x03, +}; + /* typedef void(*ps_transport_fn_pub_t)(struct ps_transport_t* transport, struct ps_pub_t* publisher, void* message); @@ -62,8 +76,7 @@ struct ps_tcp_transport_connection struct ps_tcp_client_queued_message_t { - char* data; - int32_t length; + struct ps_msg_ref_t* msg; }; struct ps_tcp_client_t @@ -77,7 +90,7 @@ struct ps_tcp_client_t int32_t desired_packet_size; char* packet_data; - char* queued_message; + struct ps_msg_ref_t* queued_message; int32_t queued_message_length; int32_t queued_message_written; @@ -99,918 +112,12 @@ struct ps_tcp_transport_impl int num_connections; }; -void remove_client_socket(struct ps_tcp_transport_impl* transport, int socket, struct ps_node_t* node) -{ - // find the index - int i = 0; - for (; i < transport->num_clients; i++) - { - if (transport->clients[i].socket == socket)// socket packed in address - { - break; - } - } - -#ifdef _WIN32 - closesocket(socket); -#else - close(socket); -#endif - - if (transport->clients[i].packet_data) - { - free(transport->clients[i].packet_data); - } - - if (transport->clients[i].queued_message) - { - free(transport->clients[i].queued_message); - } - - // free queued messages - if (transport->clients[i].num_queued_messages) - { - for (int j = 0; j < transport->clients[i].num_queued_messages; j++) - { - free(transport->clients[i].queued_messages[j].data); - } - free(transport->clients[i].queued_messages); - } - - struct ps_tcp_client_t* old_clients = transport->clients; - transport->num_clients -= 1; - - // close the socket and dont wait on it anymore - ps_event_set_remove_socket(&node->events, transport->clients[i].socket); +void ps_tcp_transport_destroy(struct ps_transport_t* transport); - if (transport->num_clients) - { - transport->clients = (struct ps_tcp_client_t*)malloc(sizeof(struct ps_tcp_client_t) * transport->num_clients); - for (int j = 0; j < i; j++) - { - transport->clients[j] = old_clients[j]; - } +void ps_tcp_transport_init(struct ps_transport_t* transport, struct ps_node_t* node); - for (int j = i + 1; j <= transport->num_clients; j++) - { - transport->clients[j - 1] = old_clients[j]; - } - } - free(old_clients); -} - -void ps_tcp_remove_connection(struct ps_tcp_transport_impl* impl, int index) -{ - // Free our subscribers and any buffers - int iter = 0; - int new_size = impl->num_connections - 1; - struct ps_tcp_transport_connection* new_connections = new_size == 0 ? 0 : (struct ps_tcp_transport_connection*)malloc(sizeof(struct ps_tcp_transport_connection) * new_size); - for (int i = 0; i < impl->num_connections; i++) - { - if (i != index) - { - new_connections[iter++] = impl->connections[i]; - continue; - } - - if (!impl->connections[i].waiting_for_header) - { - free(impl->connections[i].packet_data); - } - ps_event_set_remove_socket(&impl->node->events, impl->connections[i].socket); -#ifdef _WIN32 - closesocket(impl->connections[i].socket); -#else - close(impl->connections[i].socket); -#endif - } - impl->num_connections = new_size; - free(impl->connections); - impl->connections = new_connections; +#ifdef __cplusplus } - -int ps_tcp_transport_spin(struct ps_transport_t* transport, struct ps_node_t* node) -{ - struct ps_tcp_transport_impl* impl = (struct ps_tcp_transport_impl*)transport->impl; - int socket = accept(impl->socket, 0, 0); - if (socket > 0) - { -#ifdef PUBSUB_VERBOSE - printf("Got new socket connection!\n"); -#endif - - // add it to the list yo - impl->num_clients++; - struct ps_tcp_client_t* old_sockets = impl->clients; - impl->clients = (struct ps_tcp_client_t*)malloc(sizeof(struct ps_tcp_client_t) * impl->num_clients); - for (int i = 0; i < impl->num_clients - 1; i++) - { - impl->clients[i] = old_sockets[i]; - } - struct ps_tcp_client_t* new_client = &impl->clients[impl->num_clients - 1]; - new_client->socket = socket; - new_client->socket = socket; - new_client->needs_removal = false; - new_client->current_packet_size = 0; - new_client->desired_packet_size = 0; - new_client->packet_data = 0; - new_client->queued_message = 0; - new_client->queued_message_length = 0; - new_client->queued_message_written = 0; - new_client->queued_messages = 0; - new_client->num_queued_messages = 0; - - // set non-blocking -#ifdef _WIN32 - DWORD nonBlocking = 1; - if (ioctlsocket(socket, FIONBIO, &nonBlocking) != 0) - { - printf("Failed to Set Socket as Non-Blocking!\n"); - closesocket(socket); - return 0; - } -#endif -#ifdef ARDUINO - fcntl(socket, F_SETFL, O_NONBLOCK); -#endif -#ifdef __unix__ - int flags = fcntl(socket, F_GETFL); - fcntl(socket, F_SETFL, flags | O_NONBLOCK); -#endif - - ps_event_set_add_socket(&node->events, socket); - - if (impl->num_clients - 1) - { - free(old_sockets); - } - } - //printf("polled\n"); - -// remove any old sockets - for (int i = 0; i < impl->num_clients; i++) - { - if (impl->clients[i].needs_removal) - { - // add it to the list - struct ps_client_t client; - client.endpoint.address = impl->clients[i].socket; - client.endpoint.port = 255;// p->port; - client.stream_id = 0; - ps_pub_remove_client(impl->clients[i].publisher, &client);// todo this is probably unsafe.... - - remove_client_socket(impl, impl->clients[i].socket, impl->clients[i].publisher->node); - - i = i - 1; - break; - } - } - - // update our sockets yo - for (int i = 0; i < impl->num_clients; i++) - { - struct ps_tcp_client_t* client = &impl->clients[i]; - - // send queued messages until we block or cant send anymore - while (client->queued_message != 0) - { - int to_send = client->queued_message_length - client->queued_message_written; - int sent = send(client->socket, &client->queued_message[client->queued_message_written], to_send, 0); - if (sent > 0) - { - client->queued_message_written += sent; - } -#ifdef WIN32 - else if (sent < 0 && WSAGetLastError() != WSAEWOULDBLOCK) -#else - else if (sent < 0 && errno != EAGAIN) -#endif - { - client->needs_removal = true; - } - - if (client->queued_message_written == client->queued_message_length) - { - //printf("Message sent.\n"); - free(client->queued_message); - client->queued_message = 0; - - // we finished! check if there are more to send - if (client->num_queued_messages > 0) - { - // grab a message from the front of our message queue - client->queued_message = client->queued_messages[0].data; - client->queued_message_written = 0; - client->queued_message_length = client->queued_messages[0].length; - - client->num_queued_messages -= 1; - if (client->num_queued_messages == 0) - { - free(client->queued_messages); - client->queued_messages = 0; - continue; - } - - struct ps_tcp_client_queued_message_t* msgs = (struct ps_tcp_client_queued_message_t*)malloc(client->num_queued_messages * sizeof(struct ps_tcp_client_queued_message_t)); - for (int i = 0; i < client->num_queued_messages; i++) - { - msgs[i] = client->queued_messages[i + 1];// take from the front - } - free(client->queued_messages); - client->queued_messages = msgs; - - // continue so we can attempt to send again - } - else - { - ps_event_set_remove_socket_write(&node->events, client->socket); - break;// no more to send - } - } - else - { - break;// we couldnt send anymore atm - } - } - - // check for new data and add it to the packet if present - char buf[1500]; - // if we havent gotten a header yet, just check for that - if (client->desired_packet_size == 0) - { - const int header_size = 5; - int len = recv(client->socket, buf, header_size, MSG_PEEK); - if (len == 0) - { - client->needs_removal = true; - continue; - } - if (len < header_size) - { - continue;// no header yet - } - - char message_type = buf[0];// not used atm - - // we actually got the header! start looking for the message - len = recv(client->socket, buf, header_size, 0); - //connection->packet_type = message_type; - //client->waiting_for_header = false; - client->desired_packet_size = *(uint32_t*)&buf[1]; - //printf("Incoming message with %i bytes\n", client->desired_packet_size); - client->packet_data = (char*)malloc(client->desired_packet_size); - - client->current_packet_size = 0; - } - // read in the message - if (client->desired_packet_size != 0) - { - int remaining_size = client->desired_packet_size - client->current_packet_size; - // check for new messages and read until we hit packet size - int len = recv(client->socket, &client->packet_data[client->current_packet_size], remaining_size, 0); - if (len > 0) - { - //printf("Read %i bytes of message\n", len); - client->current_packet_size += len; - - if (client->current_packet_size == client->desired_packet_size) - { -#ifdef PUBSUB_VERBOSE - printf("message finished\n"); -#endif - - if (true)// todo look at message id - { - // its a subscribe - const char* topic = &client->packet_data[4]; - // check if this matches any of our publishers - for (unsigned int pi = 0; pi < node->num_pubs; pi++) - { - struct ps_pub_t* pub = node->pubs[pi]; - if (strcmp(topic, pub->topic) == 0) - { - uint32_t skip = *(uint32_t*)&client->packet_data[0]; - // send response and start publishing - struct ps_client_t sub_client; - sub_client.endpoint.address = client->socket; - sub_client.endpoint.port = 255;// p->port; - sub_client.last_keepalive = 10000000000000;//GetTickCount64();// use the current time stamp - sub_client.sequence_number = 0; - sub_client.stream_id = 0; - sub_client.modulo = skip > 0 ? skip + 1 : 0; - sub_client.transport = transport; - - impl->clients[i].publisher = pub; - - // send the client the acknowledgement and message definition - int8_t packet_type = 0x03;//message definition - send(impl->clients[i].socket, (char*)&packet_type, 1, 0); - - char buf[1500]; - int32_t length = ps_serialize_message_definition((void*)buf, pub->message_definition); - send(impl->clients[i].socket, (char*)&length, 4, 0); - send(impl->clients[i].socket, buf, length, 0); - -#ifdef PUBSUB_VERBOSE - printf("TCPTransport: Got subscribe request, adding client if we haven't already\n"); #endif - ps_pub_add_client(pub, &sub_client); - - break; - } - } - } - - free(client->packet_data); - client->packet_data = 0; - client->desired_packet_size = 0; - } - } - } - } - - int message_count = 0; - for (int i = 0; i < impl->num_connections; i++) - { - struct ps_tcp_transport_connection* connection = &impl->connections[i]; - char buf[1500]; - if (connection->connecting) - { - //printf("checking for connected\n"); - // select to check for writability - fd_set wfds; - struct timeval tv; - int retval; - - FD_ZERO(&wfds); - FD_SET(connection->socket, &wfds); - - tv.tv_sec = 0; - tv.tv_usec = 0; - retval = select(connection->socket + 1, NULL, &wfds, NULL, &tv); - if (retval == -1) - { - // error? - printf("socket errored while connecting\n"); - } - else if (retval) - { - // socket is writable - //printf("socket writable\n"); - - // make the subscribe request in a "packet" - // a packet is an int length followed by data - int8_t packet_type = 0x01;//subscribe - send(connection->socket, (char*)&packet_type, 1, 0); - - int32_t length = strlen(connection->subscriber->topic) + 1 + 4; - send(connection->socket, (char*)&length, 4, 0); - - // make the request - char buffer[500]; - strcpy(buffer, connection->subscriber->topic); - uint32_t skip = connection->subscriber->skip; - send(connection->socket, (char*)&skip, 4, 0); - send(connection->socket, buffer, length - 4, 0); - - connection->connecting = false; - } - } - // if we havent gotten a header yet, just check for that - else if (connection->waiting_for_header) - { - const int header_size = 5; - int len = recv(connection->socket, buf, header_size, MSG_PEEK); - //printf("peek got: %i\n", len); - if (len == 0) - { - // we got disconnected - ps_tcp_remove_connection(impl, i); - i--; - continue; - } - else if (len < header_size) - { - continue;// no header yet - } - - char message_type = buf[0]; - - // we actually got the header! start looking for the message - len = recv(connection->socket, buf, header_size, 0); - connection->packet_type = message_type; - connection->waiting_for_header = false; - connection->packet_size = *(uint32_t*)&buf[1]; - //printf("Incoming message with %i bytes\n", impl->connections[i].packet_size); - connection->packet_data = (char*)malloc(connection->packet_size); - - connection->current_size = 0; - } - else // read in the message - { - int remaining_size = connection->packet_size - connection->current_size; - - // check for new messages and read until we hit packet size - int len = recv(connection->socket, &connection->packet_data[connection->current_size], remaining_size, 0); - if (len == 0) - { - // we got disconnected - ps_tcp_remove_connection(impl, i); - i--; - continue; - } - else if (len > 0) - { - //printf("Read %i bytes of message\n", len); - connection->current_size += len; - - if (connection->current_size == connection->packet_size) - { - //printf("message finished type %x\n", connection->packet_type); - if (connection->packet_type == 0x3) - { - //printf("Was message definition\n"); - if (connection->subscriber->type == 0) - { - // todo put this in a function so we cant accidentally forget it - if (connection->subscriber->received_message_def.fields == 0) - { - ps_deserialize_message_definition(connection->packet_data, &connection->subscriber->received_message_def); - } - - // call the callback as well - if (node->def_cb) - { - node->def_cb(&connection->subscriber->received_message_def, node->def_cb_data); - } - } - - free(connection->packet_data); - } - else if (connection->packet_type == 0x2) - { - //printf("added to queue\n"); - // decode and add it to the queue - struct ps_msg_info_t message_info; - message_info.address = connection->endpoint.address; - message_info.port = connection->endpoint.port; - - void* out_data; - if (connection->subscriber->type) - { - out_data = connection->subscriber->type->decode(connection->packet_data, connection->subscriber->allocator); - free(connection->packet_data); - } - else - { - out_data = connection->packet_data; - } - - // remove the reference to packet data so we dont try and double free it on destroy - // it is the queue's responsibility now - connection->packet_data = 0; - ps_sub_enqueue(connection->subscriber, - out_data, - connection->packet_size, - &message_info); - - message_count++; - } - else - { - // unhandled packet id - free(connection->packet_data); - } - connection->waiting_for_header = true; - } - } - } - } - return message_count; -} - -void ps_tcp_transport_pub(struct ps_transport_t* transport, struct ps_pub_t* publisher, struct ps_client_t* client, const void* message, uint32_t length) -{ - struct ps_tcp_transport_impl* impl = (struct ps_tcp_transport_impl*)transport->impl; - - // the client packs the socket id in the addr - int socket = client->endpoint.address; - - // okay, so new version, if any write fails (EAGAIN or < expected size) - // we make a copy of the entire message and store it on that client to try and send again in our update loop - // if we get into this function and this client already has a queued message, just drop this one - struct ps_tcp_client_t* tclient = 0; - for (int i = 0; i < impl->num_clients; i++) - { - if (impl->clients[i].socket == socket) - { - tclient = &impl->clients[i]; - break; - } - } - - if (tclient->queued_message != 0) - { - // check if we have queue space left - - // for now hardcode max queue size - const int max_queue_size = 10; - - // copy the message to put it in the queue - // todo remove this copy - char* data = (char*)malloc(length + 4 + 1); - data[0] = 0x02; - *((uint32_t*)&data[1]) = length; - memcpy(&data[5], ps_get_msg_start(message), length); - - // this if statement is unnecessary, but I added it for the sake of testing/completeness - if (tclient->queued_message == 0) - { - tclient->queued_message = data; - tclient->queued_message_length = length + 5; - tclient->queued_message_written = 0; - } - else if (tclient->num_queued_messages >= max_queue_size) - { - // todo use a deque lol - // swap everything down, freeing the first - for (int i = tclient->num_queued_messages - 1; i >= 1; i--) - { - tclient->queued_messages[i] = tclient->queued_messages[i - 1]; - } - tclient->queued_messages[0].data = data; - tclient->queued_messages[0].length = length + 5; - //printf("dropped message on topic '%s'\n", publisher->topic); - return;// drop it, we are out of queue space - } - else - { - //printf("queuing up message %i on topic '%s'\n", tclient->num_queued_messages, publisher->topic); - - // add the message to the front of the queue - tclient->num_queued_messages += 1; - struct ps_tcp_client_queued_message_t* msgs = (struct ps_tcp_client_queued_message_t*)malloc(tclient->num_queued_messages * sizeof(struct ps_tcp_client_queued_message_t)); - - msgs[0].data = data; - msgs[0].length = length + 5; - for (int i = 0; i < tclient->num_queued_messages - 1; i++) - { - msgs[i + 1] = tclient->queued_messages[i]; - } - free(tclient->queued_messages); - tclient->queued_messages = msgs; - - return; - } - } - //printf("started writing\n"); - // try and write, if any of these fail, make a copy - uint8_t packet_type = 0x02; - int c = send(socket, (char*)&packet_type, 1, 0); - if (c == 0) - { - tclient->queued_message_written = 0; - goto FAILCOPY; - } - if (c < 0) - { -#ifdef WIN32 - int error = WSAGetLastError(); - if (error == WSAEWOULDBLOCK) -#else - if (errno == EAGAIN || errno == EWOULDBLOCK) -#endif - { - tclient->queued_message_written = 0; - goto FAILCOPY; - } - goto FAILDISCONNECT; - } - - c = send(socket, (char*)&length, 4, 0); - if (c < 4 && c >= 0) - { - tclient->queued_message_written = c + 1; - goto FAILCOPY; - } - if (c < 0) - { -#ifdef WIN32 - int error = WSAGetLastError(); - if (error == WSAEWOULDBLOCK) -#else - if (errno == EAGAIN || errno == EWOULDBLOCK) -#endif - { - tclient->queued_message_written = 1; - goto FAILCOPY; - } - goto FAILDISCONNECT; - } - - //printf("sending %i bytes\n", length + 4 + 1); - c = send(socket, (char*)ps_get_msg_start(message), length, 0); - if (c < length && c >= 0) - { - tclient->queued_message_written = c + 5; - goto FAILCOPY; - } - if (c < 0) - { -#ifdef WIN32 - int error = WSAGetLastError(); - if (error == WSAEWOULDBLOCK) -#else - if (errno == EAGAIN || errno == EWOULDBLOCK) -#endif - { - tclient->queued_message_written = 5; - goto FAILCOPY; - } - goto FAILDISCONNECT; - } - //printf("wrote all\n"); - return; - - char* data; -FAILDISCONNECT: - //printf("Disconnected: %s\n", strerror(err)); - tclient->needs_removal = true; - return; - -FAILCOPY: - // todo remove this copy - data = (char*)malloc(length + 4 + 1); - data[0] = 0x02; - *((uint32_t*)&data[1]) = length; - memcpy(&data[5], ps_get_msg_start(message), length); - - tclient->queued_message = data; - tclient->queued_message_length = length + 5; - ps_event_set_add_socket_write(&publisher->node->events, socket); - return; -} - -void ps_tcp_transport_subscribe(struct ps_transport_t* transport, struct ps_sub_t* subscriber, struct ps_endpoint_t* ep, uint32_t transport_info) -{ - struct ps_tcp_transport_impl* impl = (struct ps_tcp_transport_impl*)transport->impl; - - // check if we already have a sub for this subscriber with this endpoint - // if so, ignore it - for (int i = 0; i < impl->num_connections; i++) - { - if (impl->connections[i].endpoint.port == ep->port && - impl->connections[i].endpoint.address == ep->address && - impl->connections[i].subscriber->sub_id == subscriber->sub_id) - { - return; - } - } - -#ifdef _WIN32 - SOCKET sock = socket(AF_INET, SOCK_STREAM, 0); -#else - int sock = socket(AF_INET, SOCK_STREAM, 0); -#endif - - // set non-blocking -#ifdef _WIN32 - DWORD nonBlocking = 1; - if (ioctlsocket(sock, FIONBIO, &nonBlocking) != 0) - { - ps_print_socket_error("Failed to Set Socket as Non-Blocking"); - closesocket(sock); - return; - } -#endif -#ifdef ARDUINO - fcntl(sock, F_SETFL, O_NONBLOCK); -#endif -#ifdef __unix__ - int flags = fcntl(sock, F_GETFL); - fcntl(sock, F_SETFL, flags | O_NONBLOCK); -#endif - - // Actually connect - //printf("connecting\n"); - struct sockaddr_in server_addr; - server_addr.sin_family = AF_INET; - server_addr.sin_addr.s_addr = htonl(ep->address); - server_addr.sin_port = htons(transport_info); - int connect_result = connect(sock, (struct sockaddr*)&server_addr, sizeof(server_addr)); - if (connect_result != 0) - { -#ifdef _WIN32 - if (WSAGetLastError() != WSAEWOULDBLOCK)//WSAEINPROGRESS -#else - if (errno != EINPROGRESS) -#endif - { - ps_print_socket_error("error connecting tcp socket"); - return; - } - } - - //printf("%i %i %i %i\n", (ep->address & 0xFF000000) >> 24, (ep->address & 0xFF0000) >> 16, (ep->address & 0xFF00) >> 8, (ep->address & 0xFF)); - - // make the subscribe request in a "packet" - // a packet is an int length followed by data - /*int8_t packet_type = 0x01;//subscribe - send(sock, (char*)&packet_type, 1, 0); - - int32_t length = strlen(subscriber->topic) + 1 + 4; - send(sock, (char*)&length, 4, 0); - - // make the request - char buffer[500]; - strcpy(buffer, subscriber->topic); - uint32_t skip = subscriber->skip; - send(sock, (char*)&skip, 4, 0); - send(sock, buffer, length - 4, 0);*/ - - // add the socket to the list of connections - impl->num_connections++; - struct ps_tcp_transport_connection* old_connections = impl->connections; - impl->connections = (struct ps_tcp_transport_connection*)malloc(sizeof(struct ps_tcp_transport_connection) * impl->num_connections); - for (int i = 0; i < impl->num_connections - 1; i++) - { - impl->connections[i] = old_connections[i]; - } - - struct ps_tcp_transport_connection* new_connection = &impl->connections[impl->num_connections - 1]; - new_connection->socket = sock; - new_connection->endpoint = *ep; - new_connection->waiting_for_header = true; - new_connection->subscriber = subscriber; - new_connection->connecting = true; - - ps_event_set_add_socket(&subscriber->node->events, sock); - - if (impl->num_connections - 1) - { - free(old_connections); - } -} - -void ps_tcp_transport_unsubscribe(struct ps_transport_t* transport, struct ps_sub_t* subscriber) -{ - struct ps_tcp_transport_impl* impl = (struct ps_tcp_transport_impl*)transport->impl; - - // remove all transports which reference this subscriber - int num_to_remove = 0; - for (int i = 0; i < impl->num_connections; i++) - { - if (impl->connections[i].subscriber == subscriber) - { - num_to_remove++; - } - } - -#ifdef PUBSUB_VERBOSE - printf("Removing %i tcp subs\n", num_to_remove); -#endif - - if (num_to_remove > 0) - { - // Free our subscribers and any buffers - int iter = 0; - int new_size = impl->num_connections - num_to_remove; - struct ps_tcp_transport_connection* new_connections = new_size == 0 ? 0 : (struct ps_tcp_transport_connection*)malloc(sizeof(struct ps_tcp_transport_connection) * new_size); - for (int i = 0; i < impl->num_connections; i++) - { - if (impl->connections[i].subscriber != subscriber) - { - new_connections[iter++] = impl->connections[i]; - continue; - } - - if (!impl->connections[i].waiting_for_header) - { - free(impl->connections[i].packet_data); - } - ps_event_set_remove_socket(&impl->node->events, impl->connections[i].socket); -#ifdef _WIN32 - closesocket(impl->connections[i].socket); -#else - close(impl->connections[i].socket); -#endif - } - impl->num_connections = new_size; - free(impl->connections); - impl->connections = new_connections; - } -} - - -void ps_tcp_transport_destroy(struct ps_transport_t* transport) -{ - struct ps_tcp_transport_impl* impl = (struct ps_tcp_transport_impl*)transport->impl; - - // Free our subscribers and any buffers - for (int i = 0; i < impl->num_connections; i++) - { - if (!impl->connections[i].waiting_for_header) - { - free(impl->connections[i].packet_data); - } - ps_event_set_remove_socket(&impl->node->events, impl->connections[i].socket); -#ifdef _WIN32 - closesocket(impl->connections[i].socket); -#else - close(impl->connections[i].socket); -#endif - } - - for (int i = 0; i < impl->num_clients; i++) - { - ps_event_set_remove_socket(&impl->node->events, impl->clients[i].socket); -#ifdef _WIN32 - closesocket(impl->clients[i].socket); -#else - close(impl->clients[i].socket); -#endif - } - -#ifdef _WIN32 - closesocket(impl->socket); -#else - close(impl->socket); -#endif - - if (impl->num_clients) - { - free(impl->clients); - } - - if (impl->num_connections) - { - free(impl->connections); - } - - free(impl); -} - -void ps_tcp_transport_init(struct ps_transport_t* transport, struct ps_node_t* node) -{ -#ifdef __unix__ - signal(SIGPIPE, SIG_IGN); -#endif - - transport->spin = ps_tcp_transport_spin; - transport->subscribe = ps_tcp_transport_subscribe; - transport->unsubscribe = ps_tcp_transport_unsubscribe; - transport->destroy = ps_tcp_transport_destroy; - transport->pub = ps_tcp_transport_pub; - transport->uuid = 1; - - struct ps_tcp_transport_impl* impl = (struct ps_tcp_transport_impl*)malloc(sizeof(struct ps_tcp_transport_impl)); - - impl->num_clients = 0; - impl->num_connections = 0; - - impl->node = node; - - impl->socket = socket(AF_INET, SOCK_STREAM, 0); - - struct sockaddr_in server_addr; - server_addr.sin_family = AF_INET; - server_addr.sin_addr.s_addr = INADDR_ANY; - server_addr.sin_port = 0;// we want an ephemeral port - if (bind(impl->socket, (struct sockaddr*)&server_addr, sizeof(server_addr)) < 0) - { - ps_print_socket_error("error binding tcp transport socket"); - } - - socklen_t outlen = sizeof(struct sockaddr_in); - struct sockaddr_in outaddr; - getsockname(impl->socket, (struct sockaddr*)&outaddr, &outlen); - transport->transport_info = ntohs(outaddr.sin_port); - - printf("Bound tcp to %i\n", transport->transport_info); - - // set non-blocking -#ifdef _WIN32 - DWORD nonBlocking = 1; - if (ioctlsocket(impl->socket, FIONBIO, &nonBlocking) != 0) - { - ps_print_socket_error("Failed to Set Socket as Non-Blocking"); - closesocket(impl->socket); - return; - } -#endif -#ifdef ARDUINO - fcntl(impl->socket, F_SETFL, O_NONBLOCK); -#endif -#ifdef __unix__ - int flags = fcntl(impl->socket, F_GETFL); - fcntl(impl->socket, F_SETFL, flags | O_NONBLOCK); -#endif - - listen(impl->socket, 5); - - ps_event_set_add_socket(&node->events, impl->socket); - - transport->impl = (void*)impl; -} #endif diff --git a/include/pubsub/UDPTransport.h b/include/pubsub/UDPTransport.h index cf32d88..ac27fb9 100644 --- a/include/pubsub/UDPTransport.h +++ b/include/pubsub/UDPTransport.h @@ -17,12 +17,13 @@ struct ps_endpoint_t; struct ps_sub_t; struct ps_pub_t; struct ps_msg_t; +struct ps_msg_ref_t; struct ps_client_t; void ps_udp_subscribe(struct ps_sub_t* sub, const struct ps_endpoint_t* ep); void ps_udp_unsubscribe(struct ps_sub_t* sub); -void ps_udp_publish(struct ps_pub_t* pub, struct ps_client_t* client, struct ps_msg_t* msg); +void ps_udp_publish(struct ps_pub_t* pub, struct ps_client_t* client, struct ps_msg_ref_t* msg); #endif diff --git a/include/pubsub_cpp/Node.h b/include/pubsub_cpp/Node.h index 257330c..092ee3d 100644 --- a/include/pubsub_cpp/Node.h +++ b/include/pubsub_cpp/Node.h @@ -3,8 +3,9 @@ #include #include #include +#include #include - +#include #include #include @@ -26,22 +27,22 @@ namespace pubsub { -static std::map _remappings; +extern std::map _remappings; // how intraprocess passing works class SubscriberBase; class PublisherBase; // this mutex protects both of the below -static std::mutex _publisher_mutex; -static std::multimap _publishers; -static std::multimap _subscribers; +extern std::mutex _publisher_mutex; +extern std::multimap _publishers; +extern std::multimap _subscribers; // this assumes topic and ns are properly checked // ns should not have a leading slash, topic should if it is absolute inline std::string handle_remap(const std::string& topic, const std::string& ns) { - //printf("Handling remap of %s in ns %s\n", topic.c_str(), ns.c_str()); + //printf("Handling remap of %s in ns %s\n", topic.c_str(), ns.c_str()); // we need at least one character if (topic.length() == 0) { @@ -87,10 +88,14 @@ inline std::string handle_remap(const std::string& topic, const std::string& ns) } // okay, we had no remappings, use our namespace - if (ns.length()) - return "/" + ns + "/" + topic; - else - return "/" + topic; + if (ns.length()) + { + return "/" + ns + "/" + topic; + } + else + { + return "/" + topic; + } } // not thread safe @@ -119,7 +124,7 @@ inline void initialize(const char** args, const int argc) // valid names must be all lowercase and only inline std::string validate_name(const std::string& name, bool remove_leading_slashes = false) { - //printf("Validating %s\n", name.c_str()); + //printf("Validating %s\n", name.c_str()); for (size_t i = 0; i < name.length(); i++) { if (name[i] >= 'A' && name[i] <= 'Z') @@ -138,8 +143,8 @@ inline std::string validate_name(const std::string& name, bool remove_leading_sl return name.substr(i); } - else - { + else + { // remove any duplicate slashes we may have std::string out; if (name.length()) @@ -155,16 +160,78 @@ inline std::string validate_name(const std::string& name, bool remove_leading_sl out += name[i]; } return out; - } + } return name; } +struct ParameterContainer +{ + std::mutex lock; + double d; + float f; + int i; + std::string s; + + operator double() const + { + return d; + } + + operator float() const + { + return f; + } + + operator int() const + { + return i; + } + + operator std::string() const + { + return s; + } +}; + +class Node; +template +class Parameter +{ + friend class Node; + std::shared_ptr value; +public: + + operator T() const + { + std::lock_guard lock(value->lock); + return (T)*value; + } + + T get() const + { + std::lock_guard lock(value->lock); + return (T)*value; + } + + /*void operator=(const T& new_value) + { + // todo update parameter value + value->d = new_value; + }*/ +}; + // todo need to make sure multiple subscribers in a process share data // safe to use each node in a different thread after initialize is called // making calls to functions on the same node is not thread safe +typedef std::unique_ptr SubscriberPtr; +typedef std::unique_ptr PublisherPtr; class SubscriberBase; class Spinner; +template +class Subscriber; +template +class Publisher; class Node { friend class Spinner; @@ -206,17 +273,23 @@ class Node ~Node() { + if (params_data_) + { + ps_destroy_parameters(params_data_.get()); + } ps_node_destroy(&node_); } std::string getQualifiedName() { - if (namespace_.length()) - return "/" + namespace_ + "/" + real_name_; - else - { - return "/" + real_name_; - } + if (namespace_.length()) + { + return "/" + namespace_ + "/" + real_name_; + } + else + { + return "/" + real_name_; + } } const std::string& getName() @@ -239,15 +312,63 @@ class Node return ps_node_spin(&node_); } - inline void setEventSet(ps_event_set_t* set) - { - event_set_ = set; - } + inline void setEventSet(ps_event_set_t* set) + { + event_set_ = set; + } inline ps_event_set_t* getEventSet() { return event_set_; } + + template + SubscriberBase* subscribe(const std::string& topic, std::function&)> cb, unsigned int queue_size = 1, int preferred_transport = -1, int skip = 0) + { + return new Subscriber(*this, topic, cb, queue_size, preferred_transport, skip); + } + + template + Publisher* advertise(const std::string& topic, bool latched = false, int preferred_transport = -1) + { + return new Publisher(*this, topic, latched, preferred_transport); + } + + std::unique_ptr params_data_; + std::map> params_; + + Parameter parameter(const std::string& name, double default_value, const std::string& desc = "", + double min = -10000, double max = 10000) + { + if (!params_data_) + { + params_data_.reset(new ps_parameters); + ps_create_parameters(getNode(), params_data_.get(), [](const char* name, double value, void* data) + { + auto tthis = (Node*)data; + auto res = tthis->params_.find(name); + if (res == tthis->params_.end()) + { + return; + } + + auto shrd = res->second.lock(); + if (shrd) + { + std::lock_guard lock(shrd->lock); + shrd->d = value; + } + }, this); + } + + ps_add_parameter_double(params_data_.get(), name.c_str(), desc.c_str(), default_value, min, max); + + Parameter p; + p.value = std::make_shared(); + p.value->d = default_value; + params_[name] = p.value; + return p; + } // mark that we have a message to process void mark() { marked_ = true; } @@ -271,27 +392,35 @@ class PublisherBase std::vector subs_; public: + + virtual ~PublisherBase() {} + const std::string& GetTopic() { return remapped_topic_; } + ps_pub_t* GetPub() + { + return &publisher_; + } + Node* GetNode() { return node_; } }; -template -class Subscriber; template class Publisher: public PublisherBase { std::shared_ptr latched_msg_; public: friend class Subscriber; + + typedef std::unique_ptr> Ptr; - Publisher(Node& node, const std::string& topic, bool latched = false)// : topic_(topic) + Publisher(Node& node, const std::string& topic, bool latched = false, int preferred_transport = 0)// : topic_(topic) { node_ = &node; topic_ = topic; @@ -304,21 +433,21 @@ class Publisher: public PublisherBase remapped_topic_ = handle_remap(real_topic, node.getNamespace()); node.lock_.lock(); - ps_node_create_publisher(node.getNode(), remapped_topic_.c_str(), T::GetDefinition(), &publisher_, latched); + ps_node_create_publisher_ex(node.getNode(), remapped_topic_.c_str(), T::GetDefinition(), &publisher_, latched, preferred_transport, T::Allocator::allocator()); node.lock_.unlock(); //add me to the publisher list _publisher_mutex.lock(); _publishers.insert(std::pair(remapped_topic_, this)); - // look for any matching subscribers and add them to our list - auto iterpair = _subscribers.equal_range(topic); - for (auto it = iterpair.first; it != iterpair.second; ++it) - { - node.lock_.lock(); - subs_.push_back(it->second); - node.lock_.unlock(); - } + // look for any matching subscribers and add them to our list + auto iterpair = _subscribers.equal_range(topic); + for (auto it = iterpair.first; it != iterpair.second; ++it) + { + node.lock_.lock(); + subs_.push_back(it->second); + node.lock_.unlock(); + } _publisher_mutex.unlock(); } @@ -367,11 +496,9 @@ class Publisher: public PublisherBase // loop through shared subscribers node_->lock_.lock(); // now go through my local subscriber list - for (size_t i = 0; i < subs_.size(); i++) + for (auto& sub: subs_) { - auto sub = subs_[i]; - - //printf("Publishing locally with no copy..\n"); + printf("Publishing locally with no copy..\n"); auto specific_sub = (Subscriber*)sub; ps_event_set_trigger(specific_sub->node_->getEventSet()); @@ -394,7 +521,8 @@ class Publisher: public PublisherBase void publish(const T& msg) { std::shared_ptr copy; - if (latched_) { + if (latched_) + { copy = std::shared_ptr(new T); *copy = msg; // save for later @@ -402,11 +530,9 @@ class Publisher: public PublisherBase } node_->lock_.lock(); // now go through my local subscriber list - for (size_t i = 0; i < subs_.size(); i++) + for (auto& sub: subs_) { - auto sub = subs_[i]; - - //printf("Publishing locally with a copy..\n"); + printf("Publishing locally with a copy..\n"); if (!copy) { //copy to shared ptr @@ -436,7 +562,7 @@ class Publisher: public PublisherBase unsigned int getNumSubscribers() { - return ps_pub_get_subscriber_count(&publisher_); + return ps_pub_get_subscriber_count(&publisher_) + subs_.size(); } void addCustomEndpoint(const int ip_addr, const short port, const unsigned int stream_id) @@ -472,7 +598,6 @@ class SubscriberBase // if its latched, get the message from it if (it->second->latched_) { - // hmm, this should just queue not call cb(it->second); } // add me to its sub list @@ -481,7 +606,7 @@ class SubscriberBase } } - _subscribers.insert(std::pair(topic, sub)); + _subscribers.insert(std::pair(topic, sub)); _publisher_mutex.unlock(); } @@ -499,26 +624,30 @@ class SubscriberBase // remove me from its list if im there auto pos = std::find(it->second->subs_.begin(), it->second->subs_.end(), sub); if (pos != it->second->subs_.end()) + { it->second->subs_.erase(pos); + } it->second->GetNode()->lock_.unlock(); } } - //remove me from the subscriber list - auto subiterpair = _subscribers.equal_range(topic); - for (auto it = subiterpair.first; it != subiterpair.second; ++it) - { - if (it->second == sub) - { - _subscribers.erase(it); - break; - } - } + //remove me from the subscriber list + auto subiterpair = _subscribers.equal_range(topic); + for (auto it = subiterpair.first; it != subiterpair.second; ++it) + { + if (it->second == sub) + { + _subscribers.erase(it); + break; + } + } _publisher_mutex.unlock(); } public: + virtual ~SubscriberBase() {} + ps_sub_t* GetSub() { return &subscriber_; @@ -528,8 +657,6 @@ class SubscriberBase virtual bool CallOne() = 0; }; - - template class Subscriber: public SubscriberBase { @@ -545,7 +672,9 @@ class Subscriber: public SubscriberBase public: - Subscriber(Node& node, const std::string& topic, std::function&)> cb, unsigned int queue_size = 1, int preferred_transport = 0) : cb_(cb), queue_size_(queue_size) + typedef std::unique_ptr> Ptr; + + Subscriber(Node& node, const std::string& topic, std::function&)> cb, unsigned int queue_size = 1, int preferred_transport = -1, int skip = 0) : cb_(cb), queue_size_(queue_size) { node_ = &node; @@ -571,15 +700,13 @@ class Subscriber: public SubscriberBase struct ps_subscriber_options options; ps_subscriber_options_init(&options); - - options.queue_size = 0; + options.skip = skip; options.cb = cb2; options.cb_data = this; - options.allocator = 0; + options.allocator = T::Allocator::allocator(); options.ignore_local = true; options.preferred_transport = preferred_transport; - node.lock_.lock(); ps_node_create_subscriber_adv(node.getNode(), remapped_topic_.c_str(), T::GetDefinition(), &subscriber_, &options); node.subscribers_.push_back(this); @@ -615,6 +742,34 @@ class Subscriber: public SubscriberBase return false; } + std::shared_ptr PopOne() + { + queue_mutex_.lock(); + if (!queue_.size()) + { + queue_mutex_.unlock(); + return {}; + } + auto back = queue_.back(); + queue_.pop_back(); + queue_mutex_.unlock(); + return back; + } + + void PushOne(const std::shared_ptr& msg) + { + auto specific_sub = this; + //ps_event_set_trigger(specific_sub->node_->getEventSet()); + specific_sub->queue_mutex_.lock(); + specific_sub->queue_.push_front(msg); + if (specific_sub->queue_.size() > specific_sub->queue_size_) + { + specific_sub->queue_.pop_back(); + } + specific_sub->queue_mutex_.unlock(); + //specific_sub->node_->mark(); + } + ~Subscriber() { close(); @@ -632,18 +787,15 @@ class Subscriber: public SubscriberBase node_->lock_.lock(); auto it = std::find(node_->subscribers_.begin(), node_->subscribers_.end(), this); if (it != node_->subscribers_.end()) + { node_->subscribers_.erase(it); + } ps_sub_destroy(&subscriber_); node_->lock_.unlock(); node_ = 0; } - T* deque() - { - return (T*)ps_sub_deque(&subscriber_); - } - const std::string& getQualifiedTopic() { return remapped_topic_; diff --git a/include/pubsub_cpp/Spinners.h b/include/pubsub_cpp/Spinners.h index b07bb2d..bafccd6 100644 --- a/include/pubsub_cpp/Spinners.h +++ b/include/pubsub_cpp/Spinners.h @@ -123,6 +123,7 @@ class BlockingSpinnerWithTimers BlockingSpinnerWithTimers(int num_threads = 1) : running_(false), node_(0) { //ps_event_set_create(&events_); + stop(true); } ~BlockingSpinnerWithTimers() @@ -138,7 +139,7 @@ class BlockingSpinnerWithTimers void setNode(Node& node) { - node_ = &node; + node_ = &node; //list_mutex_.lock(); // build a wait list for all nodes @@ -154,13 +155,12 @@ class BlockingSpinnerWithTimers { while (running_ && ps_okay()) { - //printf("Entering thread\n"); list_mutex_.lock(); if (node_ == 0)//ps_event_set_count(&events_) == 0) { - printf("Waiting for events\n"); + printf("Waiting for events\n"); ps_sleep(10); - continue; + continue; } else { @@ -184,15 +184,15 @@ class BlockingSpinnerWithTimers // this line is necessary anyways, but happens to work around the above bug timeout = std::min(timeout, 1000000);// make sure we dont block too long - //printf("setting timeout to %i\n", timeout); + //printf("setting timeout to %i\n", timeout); } list_mutex_.unlock(); - //ps_node_wait(node_->getNode(), 0); + //ps_node_wait(node_->getNode(), 0); // todo allow finer grained waits //if (timeout > 2) { // allows for fine grained waits - ps_event_set_set_timer(&node_->getNode()->events, timeout);// in us + ps_event_set_set_timer(&node_->getNode()->events, timeout);// in us ps_node_wait(node_->getNode(), 1000); //ps_event_set_wait(&events_, timeout); } @@ -205,7 +205,7 @@ class BlockingSpinnerWithTimers Time now = Time::now(); if (now >= timer.next_trigger) { - //printf("Calling timer\n"); + //printf("Calling timer\n"); timer.next_trigger = timer.next_trigger + timer.period; timer.fn(); } @@ -213,7 +213,7 @@ class BlockingSpinnerWithTimers // check all nodes //for (auto node : nodes_) - if (node_) + if (node_) { node_->lock_.lock(); if (ps_node_spin(node_->getNode()) || node_->marked()) @@ -249,7 +249,12 @@ class BlockingSpinnerWithTimers void wait() { - if (!running_) + thread_.join(); + } + + void run() + { + if (!running_) { start(); } @@ -265,8 +270,10 @@ class BlockingSpinnerWithTimers } running_ = false; - if (join && thread_.joinable()) - thread_.join();// wait for it to stop + if (join && thread_.joinable()) + { + thread_.join();// wait for it to stop + } } }; @@ -301,7 +308,7 @@ class Spinner int res = 0; if (res = ps_node_spin(node->getNode()) || node->marked()) { - printf("Received %i messages\n", res); + printf("Received %i messages\n", res); // we got a message, now call a subscriber // todo how to make this not scale with subscriber count... for (size_t i = 0; i < node->subscribers_.size(); i++) @@ -311,7 +318,6 @@ class Spinner } } - node->lock_.unlock(); } list_mutex_.unlock(); diff --git a/include/pubsub_cpp/Time.h b/include/pubsub_cpp/Time.h index bf07cf7..8ba695a 100644 --- a/include/pubsub_cpp/Time.h +++ b/include/pubsub_cpp/Time.h @@ -38,17 +38,40 @@ class Duration } - bool operator<(const Duration& rhs) const// otherwise, both parameters may be const references + bool operator==(const Duration& rhs) const + { + return this->usec == rhs.usec; + } + + bool operator!=(const Duration& rhs) const + { + return this->usec != rhs.usec; + } + + bool operator<(const Duration& rhs) const { return this->usec < rhs.usec; } - bool operator>(const Duration& rhs) const// otherwise, both parameters may be const references + bool operator>(const Duration& rhs) const { return this->usec > rhs.usec; } + + Duration operator+(const Duration& rhs) const + { + Duration out; + out.usec = this->usec + rhs.usec; + return out; + } + + Duration& operator+=(const Duration& rhs) + { + this->usec += rhs.usec; + return *this; + } - double toSec() + double toSec() const { return usec / 1000000.0; } @@ -77,38 +100,54 @@ class Time } - Duration operator-(const Time& rhs) // otherwise, both parameters may be const references + Duration operator-(const Time& rhs) const { Duration out; out.usec = this->usec - rhs.usec; - return out; // return the result by value (uses move constructor) + return out; + } + + bool operator==(const Time& rhs) const + { + return this->usec == rhs.usec; + } + + bool operator!=(const Time& rhs) const + { + return this->usec != rhs.usec; } - bool operator<(const Time& rhs) const// otherwise, both parameters may be const references + bool operator<(const Time& rhs) const { return this->usec < rhs.usec; } - bool operator<=(const Time& rhs) const// otherwise, both parameters may be const references + bool operator<=(const Time& rhs) const { return this->usec <= rhs.usec; } - bool operator>(const Time& rhs) const// otherwise, both parameters may be const references + bool operator>(const Time& rhs) const { return this->usec > rhs.usec; } - bool operator>=(const Time& rhs) const// otherwise, both parameters may be const references + bool operator>=(const Time& rhs) const { return this->usec >= rhs.usec; } - Time operator+(const Duration& rhs) // otherwise, both parameters may be const references + Time operator+(const Duration& rhs) const { Time out; out.usec = this->usec + rhs.usec; - return out; // return the result by value (uses move constructor) + return out; + } + + Time& operator+=(const Duration& rhs) + { + this->usec += rhs.usec; + return *this; } static Time now() @@ -147,12 +186,12 @@ class Time #endif } - double toSec() + double toSec() const { return usec / 1000000.0; } - std::string toString() + std::string toString() const { time_t t = usec / 1000000;// toSec(); diff --git a/include/pubsub_cpp/allocator.h b/include/pubsub_cpp/allocator.h new file mode 100644 index 0000000..bae8328 --- /dev/null +++ b/include/pubsub_cpp/allocator.h @@ -0,0 +1,13 @@ +#pragma once + +namespace pubsub +{ +struct DefaultAllocator +{ + static ps_allocator_t* allocator() + { + return &ps_default_allocator; + } +}; + +} diff --git a/include/pubsub_cpp/array_string.h b/include/pubsub_cpp/array_string.h new file mode 100644 index 0000000..81dc79e --- /dev/null +++ b/include/pubsub_cpp/array_string.h @@ -0,0 +1,209 @@ + +#pragma once + +#include +#include +#include +#include + +namespace pubsub +{ + +#pragma pack(push, 1) +// Wrapper for a fixed size array for using it like a string +template +class FixedString +{ + char data_[string_length]; +public: + + FixedString() + { + data_[0] = 0; + } + + operator const char*() + { + return data_; + } + + void operator=(const std::string& string) + { + if (string.length() >= string_length) + { + throw std::runtime_error("Too big."); + } + strcpy(data_, string.c_str()); + } + + void operator=(const char* string) + { + // todo make this more efficient? + if (strlen(string) >= string_length) + { + throw std::runtime_error("Too big."); + } + strncpy(data_, string, string_length - 1); + } + + bool operator==(const char* other) const + { + return strcmp(other, data_) == 0; + } + + bool operator==(const std::string& other) const + { + return strcmp(other.c_str(), data_) == 0; + } + + bool operator==(const FixedString& other) const + { + return strcmp(other.c_str(), data_) == 0; + } + + char* data() const + { + return data_; + } + + const char* c_str() const + { + return data_; + } + + int max_size() const + { + return string_length; + } + + int length() const + { + return strlen(data_); + } +}; + +// Wrapper for a C string which makes it easier to use and handles freeing +template +class CString +{ + char* data_; + + // copy a buffer of a given length to a new allocated array + static char* copy(const char* obj, uint32_t length) + { + auto data = (char*)Allocator::allocator()->alloc(length, Allocator::allocator()->context); + for (int i = 0; i < length; i++) + { + data[i] = obj[i]; + } + return data; + } + + static char* copy(const char* obj) + { + auto length = strlen(obj) + 1; + auto data = (char*)Allocator::allocator()->alloc(length, Allocator::allocator()->context); + for (int i = 0; i < length; i++) + { + data[i] = obj[i]; + } + return data; + } +public: + + CString() + { + data_ = 0; + } + + CString(const CString& other) + { + data_ = 0; + if (other.data_) + { + data_ = copy(other.data_); + } + } + + ~CString() + { + if (data_) + { + Allocator::allocator()->free(data_, Allocator::allocator()->context); + } + } + + void operator=(const std::string& string) + { + if (data_) + { + Allocator::allocator()->free(data_, Allocator::allocator()->context); + } + data_ = copy(string.c_str(), string.length()+1); + } + + void operator=(const char* string) + { + if (data_) + { + Allocator::allocator()->free(data_, Allocator::allocator()->context); + } + data_ = copy(string); + } + + void operator=(const CString& other) + { + if (data_) + { + Allocator::allocator()->free(data_, Allocator::allocator()->context); + data_ = 0; + } + if (other.data_) + { + data_ = copy(other.data_); + } + } + + bool operator==(const char* other) const + { + if (data_ == 0) + { + return strcmp(other, "") == 0; + } + return strcmp(other, data_) == 0; + } + + bool operator==(const std::string& other) const + { + if (data_ == 0) + { + return strcmp(other.c_str(), "") == 0; + } + return strcmp(other.c_str(), data_) == 0; + } + + bool operator==(const CString& other) const + { + if (data_ == 0) + { + return strcmp(other.c_str(), "") == 0; + } + return strcmp(other.c_str(), data_) == 0; + } + + char* data() const + { + return data_; + } + + const char* c_str() const + { + if (data_ == 0) + { + return ""; + } + return data_; + } +}; +#pragma pack(pop) +} diff --git a/include/pubsub_cpp/array_vector.h b/include/pubsub_cpp/array_vector.h index 84e5e03..3961783 100644 --- a/include/pubsub_cpp/array_vector.h +++ b/include/pubsub_cpp/array_vector.h @@ -2,15 +2,30 @@ #pragma once #include +#include +namespace pubsub +{ #pragma pack(push, 1) // Vector like array that's able to take ownership of C malloced arrays -template +template class ArrayVector { T* data_; uint32_t length_; + + // copy a buffer of a given length to a new allocated array + static T* copy(const T* obj, uint32_t length) + { + auto data = (T*)Allocator::allocator()->alloc(sizeof(T)*length, Allocator::allocator()->context); + for (int i = 0; i < length; i++) + { + data[i] = obj[i]; + } + return data; + } + public: ArrayVector() @@ -28,34 +43,26 @@ class ArrayVector ArrayVector(const ArrayVector& obj) { length_ = obj.length_; - data_ = (T*)malloc(sizeof(T)*length_); - for (int i = 0; i < length_; i++) - { - data_[i] = obj[i]; - } + data_ = copy(obj.data(), length_); } ~ArrayVector() { if (data_) { - free(data_); + Allocator::allocator()->free(data_, Allocator::allocator()->context); } } - ArrayVector& operator=(const ArrayVector& obj) + ArrayVector& operator=(const ArrayVector& arr) { if (data_) { - free(data_); + Allocator::allocator()->free(data_, Allocator::allocator()->context); } - length_ = obj.length_; - data_ = (T*)malloc(sizeof(T)*length_); - for (int i = 0; i < length_; i++) - { - data_[i] = obj[i]; - } + length_ = arr.length_; + data_ = copy(arr.data(), length_); return *this; } @@ -63,15 +70,11 @@ class ArrayVector { if (data_) { - free(data_); + Allocator::allocator()->free(data_, Allocator::allocator()->context); } length_ = arr.size(); - data_ = (T*)malloc(sizeof(T)*length_); - for (int i = 0; i < length_; i++) - { - data_[i] = arr[i]; - } + data_ = copy(arr.data(), length_); return *this; } @@ -80,24 +83,33 @@ class ArrayVector { if (size == length_) { return; } - auto new_data = (T*)malloc(sizeof(T)*size); + auto new_data = (T*)Allocator::allocator()->alloc(sizeof(T)*size, Allocator::allocator()->context); auto copy_len = std::min(size, length_); - for (int i = 0; i < copy_len; i++) + for (uint32_t i = 0; i < copy_len; i++) { new_data[i] = data_[i]; } length_ = size; if (data_) { - free(data_); + Allocator::allocator()->free(data_, Allocator::allocator()->context); } data_ = new_data; } + + // reliquinquishes the held pointer without freeing + T* reset() + { + auto out = data_; + data_ = 0; + length_ = 0; + return out; + } void clear() { length_ = 0; - free(data_); + Allocator::allocator()->free(data_, Allocator::allocator()->context); data_ = 0; } @@ -118,3 +130,4 @@ class ArrayVector inline const_iterator end() const { return &data_[length_]; } }; #pragma pack(pop) +} diff --git a/msg/Int.msg b/msg/Int.msg new file mode 100644 index 0000000..18ba850 --- /dev/null +++ b/msg/Int.msg @@ -0,0 +1 @@ +int64 value diff --git a/msg/Parameters.msg b/msg/Parameters.msg index 991b929..27fdf97 100644 --- a/msg/Parameters.msg +++ b/msg/Parameters.msg @@ -11,9 +11,8 @@ enum uint8 type[] # The current value for the parameter string value[] -# Inclusive maximum valid value for the parameter +# Inclusive minimum valid value for the parameter double min[] -# Inclusive minimum valid value for the parameter -# Not valid for max +# Inclusive maximum valid value for the parameter double max[] diff --git a/src/Bindings.c b/src/Bindings.c index 7c37145..bd41d27 100644 --- a/src/Bindings.c +++ b/src/Bindings.c @@ -161,7 +161,7 @@ EXPORT int ps_create_publisher(int node, const char* topic, const char* definiti struct ps_msg_field_t* field = &fields[num_fields-1]; field->name = name; field->length = 1; - field->content_length = 0; + field->string_length = 0; field->type = 0;// filled in below field->flags = 0; if (strcmp(type, "int8") == 0) @@ -226,7 +226,7 @@ EXPORT int ps_create_publisher(int node, const char* topic, const char* definiti EXPORT void ps_publish(int pub, const void* msg, int len) { - // publish the message simply since it is already encoded + // publish the message simply since it is already encoded struct ps_msg_t omsg; ps_msg_alloc(len, 0, &omsg); memcpy(ps_get_msg_start(omsg.data), msg, len); diff --git a/src/Events.c b/src/Events.c index cdced01..321876e 100644 --- a/src/Events.c +++ b/src/Events.c @@ -1,6 +1,7 @@ #include #include +#include #ifndef WIN32 #include @@ -14,12 +15,23 @@ void ps_event_set_create(struct ps_event_set_t* set) { #ifdef WIN32 - set->num_handles = 1; + // Old version which had an event per socket + /*set->num_handles = 1; set->handles = (HANDLE*)malloc(sizeof(HANDLE)); set->sockets = (int*)malloc(sizeof(int)); + set->handles[0] = WSACreateEvent(); + set->sockets[0] = -1;*/ + + // New version which uses one event for all sockets + set->num_handles = 3; + set->handles = (HANDLE*)malloc(sizeof(HANDLE)*3); + set->sockets = (int*)malloc(sizeof(int)); + set->handles[0] = WSACreateEvent(); set->sockets[0] = -1; + set->handles[1] = WSACreateEvent(); + set->handles[2] = CreateWaitableTimer(NULL, TRUE, NULL); #else set->fd = epoll_create(1); set->num_events = 0; @@ -53,8 +65,10 @@ void ps_event_set_destroy(struct ps_event_set_t* set) void ps_event_set_add_socket(struct ps_event_set_t* set, int socket) { + //printf("add socket read %i\n", socket); #ifdef WIN32 - // allocate a new spot + // Old version which creates an event per socket + /*// allocate a new spot int cur_size = set->num_handles; HANDLE* new_handles = (HANDLE*)malloc(sizeof(HANDLE)*(cur_size + 1)); int* new_sockets = (int*)malloc(sizeof(int)*(cur_size + 1)); @@ -72,7 +86,10 @@ void ps_event_set_add_socket(struct ps_event_set_t* set, int socket) set->handles[cur_size + 0] = WSACreateEvent(); set->sockets[cur_size + 0] = socket; - WSAEventSelect(socket, set->handles[cur_size + 0], FD_READ); + WSAEventSelect(socket, set->handles[cur_size + 0], FD_READ);*/ + + // New version which always uses the same event + WSAEventSelect(socket, set->handles[1], FD_READ); #else struct epoll_event event; event.events = EPOLLIN; @@ -84,8 +101,10 @@ void ps_event_set_add_socket(struct ps_event_set_t* set, int socket) void ps_event_set_add_socket_write(struct ps_event_set_t* set, int socket) { + //printf("add socket read write %i\n", socket); #ifdef WIN32 - // find the handle and change the select + // Old version which creates an event per socket + /*// find the handle and change the select for (unsigned int i = 0; i < set->num_handles; i++) { if (set->sockets[i] == socket) @@ -93,7 +112,10 @@ void ps_event_set_add_socket_write(struct ps_event_set_t* set, int socket) WSAEventSelect(socket, set->handles[i], FD_READ | FD_WRITE); break; } - } + }*/ + + // New version which always uses the same event + WSAEventSelect(socket, set->handles[1], FD_READ | FD_WRITE); #else struct epoll_event event; event.events = EPOLLIN | EPOLLOUT; @@ -104,7 +126,21 @@ void ps_event_set_add_socket_write(struct ps_event_set_t* set, int socket) void ps_event_set_add_socket_write_only(struct ps_event_set_t* set, int socket) { + //printf("add socket write only %i\n", socket); #ifdef WIN32 + // Old version which creates an event per socket + /*// find the handle and change the select + for (int i = 0; i < set->num_handles; i++) + { + if (set->sockets[i] == socket) + { + WSAEventSelect(socket, set->handles[i], FD_READ | FD_WRITE); + break; + } + }*/ + + // New version which always uses the same event + WSAEventSelect(socket, set->handles[1], FD_WRITE); #else struct epoll_event event; event.events = EPOLLOUT; @@ -115,8 +151,10 @@ void ps_event_set_add_socket_write_only(struct ps_event_set_t* set, int socket) void ps_event_set_remove_socket_write(struct ps_event_set_t* set, int socket) { + //printf("remove socket write %i\n", socket); #ifdef WIN32 - // find the handle and change the select + // Old version which creates an event per socket + /*// find the handle and change the select for (unsigned int i = 0; i < set->num_handles; i++) { if (set->sockets[i] == socket) @@ -124,7 +162,10 @@ void ps_event_set_remove_socket_write(struct ps_event_set_t* set, int socket) WSAEventSelect(socket, set->handles[i], FD_READ); break; } - } + }*/ + + // New version which always uses the same event + WSAEventSelect(socket, set->handles[1], FD_READ); #else struct epoll_event event; event.events = EPOLLIN; @@ -136,8 +177,10 @@ void ps_event_set_remove_socket_write(struct ps_event_set_t* set, int socket) void ps_event_set_remove_socket(struct ps_event_set_t* set, int socket) { + //printf("remove socket %i\n", socket); #ifdef WIN32 - // find the socket to remove then remove it + // Old version which creates an event per socket + /*// find the socket to remove then remove it bool found = false; unsigned int index = 0; for (; index < set->num_handles; index++) @@ -179,8 +222,10 @@ void ps_event_set_remove_socket(struct ps_event_set_t* set, int socket) free(set->handles); free(set->sockets); set->handles = new_handles; - set->sockets = new_sockets; + set->sockets = new_sockets;*/ + // New version which always uses the same event + WSAEventSelect(socket, set->handles[1], 0); #else struct epoll_event event; event.events = EPOLLIN; @@ -206,7 +251,7 @@ void ps_event_set_trigger(struct ps_event_set_t* set) unsigned int ps_event_set_count(const struct ps_event_set_t* set) { #ifdef WIN32 - return set->num_handles-1; + return set->num_handles-1;// this is wrong, but oh well #else return set->num_events; #endif @@ -242,7 +287,19 @@ void ps_event_set_wait(struct ps_event_set_t* set, unsigned int timeout_ms) void ps_event_set_set_timer(struct ps_event_set_t* set, unsigned int timeout_us) { #ifdef WIN32 - // todo + // todo test + // use CreateWaitableTimer and SetWaitableTimer + WaitForMultipleObjectsEx + + LARGE_INTEGER due; + due.QuadPart = 10*timeout_us;// in 100ns intervals + due.QuadPart = -due.QuadPart;// negative = relative + + if (!SetWaitableTimer(set->handles[2], &due, 0, NULL, NULL, 0)) + { + printf("CreateWaitableTimer failed (%d)\n", GetLastError()); + return; + } + #else if (set->timer_fd == 0) { diff --git a/src/Node.c b/src/Node.c index 30af011..defbf90 100644 --- a/src/Node.c +++ b/src/Node.c @@ -17,6 +17,13 @@ #include #endif +#ifndef _WIN32 +#include +#include +#include +#include +#endif + // sends out a system query message for all nodes to advertise void ps_node_system_query(struct ps_node_t* node) { @@ -114,6 +121,7 @@ void ps_node_advertise(struct ps_pub_t* pub) p->addr = pub->node->addr; p->port = pub->node->port; p->flags = pub->latched ? PS_ADVERTISE_LATCHED : 0; + p->flags |= (pub->recommended_transport << 1) & 0b11111; p->type_hash = pub->message_definition->hash; p->transports = pub->node->supported_transports; p->group_id = pub->node->group_id; @@ -134,8 +142,7 @@ void ps_node_advertise(struct ps_pub_t* pub) int sent_bytes = sendto(pub->node->socket, (const char*)data, off, 0, (struct sockaddr*)&address, sizeof(struct sockaddr_in)); } - -void ps_node_create_publisher(struct ps_node_t* node, const char* topic, const struct ps_message_definition_t* type, struct ps_pub_t* pub, bool latched) +void ps_node_create_publisher_ex(struct ps_node_t* node, const char* topic, const struct ps_message_definition_t* type, struct ps_pub_t* pub, bool latched, unsigned int recommended_transport, struct ps_allocator_t* allocator) { node->num_pubs++; struct ps_pub_t** old_pubs = node->pubs; @@ -154,13 +161,19 @@ void ps_node_create_publisher(struct ps_node_t* node, const char* topic, const s pub->topic = topic; pub->node = node; pub->latched = latched; - pub->last_message.data = 0; - pub->last_message.len = 0; + pub->last_message = 0; pub->sequence_number = 0; + pub->recommended_transport = recommended_transport; + pub->allocator = allocator ? allocator : &ps_default_allocator; ps_node_advertise(pub); } +void ps_node_create_publisher(struct ps_node_t* node, const char* topic, const struct ps_message_definition_t* type, struct ps_pub_t* pub, bool latched) +{ + ps_node_create_publisher_ex(node, topic, type, pub, latched, 0, 0); +} + // Setup Control-C handlers #ifdef _WIN32 static int ps_shutdown = 0; @@ -293,10 +306,11 @@ void ps_node_init_ex(struct ps_node_t* node, const char* name, const char* ip, b { ip = GetPrimaryIp(); } + uint32_t our_address = inet_addr(ip); printf("Pubsub IP: %s\n", ip); #ifdef _WIN32 - node->group_id = GetCurrentProcessId() + (10000 * ((inet_addr(ip) >> 24) && 0xFF)); + node->group_id = GetCurrentProcessId() + (10000 * ((our_address >> 24) && 0xFF)); #else node->group_id = 0;// ignore the group #endif @@ -311,9 +325,35 @@ void ps_node_init_ex(struct ps_node_t* node, const char* name, const char* ip, b } else if (broadcast) { - //convert to a broadcast address (just the subnet wide one) - node->advertise_addr = inet_addr(ip); + //convert to a broadcast address (just the subnet wide one) + node->advertise_addr = our_address; + //okay, for this to work we need the subnet address we're assuming and its sometimes wrong node->advertise_addr |= 0xFF000000; +#ifndef _WIN32 + struct ifaddrs *ifap, *ifa; + struct sockaddr_in *sa; + + getifaddrs(&ifap); + for (ifa = ifap; ifa; ifa = ifa->ifa_next) + { + if (ifa->ifa_addr && ifa->ifa_addr->sa_family == AF_INET) + { + sa = (struct sockaddr_in*)ifa->ifa_addr; + if (sa->sin_addr.s_addr == our_address) + { + sa = (struct sockaddr_in*)ifa->ifa_ifu.ifu_broadaddr; + node->advertise_addr = sa->sin_addr.s_addr; + printf("found\n"); + } + } + } + + freeifaddrs(ifap); +#endif + // print the result + struct in_addr ip_addr; + ip_addr.s_addr = node->advertise_addr; + printf("Broadcast Address: %s\n", inet_ntoa(ip_addr)); } else { @@ -325,7 +365,7 @@ void ps_node_init_ex(struct ps_node_t* node, const char* name, const char* ip, b // Setup the core socket node->socket = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); - node->addr = ntohl(inet_addr(ip)); + node->addr = ntohl(our_address); if (node->socket == 0) { printf("Failed To Create Socket!\n"); @@ -508,7 +548,7 @@ void* ps_malloc_alloc(unsigned int size, void* _) return malloc(size); } -void ps_malloc_free(void* data) +void ps_malloc_free(void* data, void* _) { free(data); } @@ -517,13 +557,13 @@ struct ps_allocator_t ps_default_allocator = { ps_malloc_alloc, ps_malloc_free, void ps_subscriber_options_init(struct ps_subscriber_options* options) { - options->queue_size = 1; options->ignore_local = false; options->allocator = 0; options->skip = 0; options->cb = 0; + options->cb_raw = 0; options->cb_data = 0; - options->preferred_transport = PS_TRANSPORT_UDP; + options->preferred_transport = -1;// no preference } void ps_node_create_subscriber_adv(struct ps_node_t* node, const char* topic, const struct ps_message_definition_t* type, @@ -559,57 +599,15 @@ void ps_node_create_subscriber_adv(struct ps_node_t* node, const char* topic, co sub->received_message_def.hash = 0; sub->received_message_def.num_fields = 0; - // force queue size to be > 0 - unsigned int queue_size = options->queue_size; - if (options->cb) - { - sub->cb = options->cb; - sub->cb_data = options->cb_data; - sub->queue_size = 0; - sub->queue_len = 0; - sub->queue_start = 0; - sub->queue = 0; - } - else - { - if (queue_size <= 0) - { - queue_size = 1; - } - - // allocate queue data - sub->queue_len = 0; - sub->queue_start = 0; - sub->queue_size = queue_size; - sub->queue = (void**)malloc(sizeof(void*) * queue_size); - - for (unsigned int i = 0; i < queue_size; i++) - { - sub->queue[i] = 0; - } - } + sub->cb = options->cb; + sub->cb_raw = options->cb_raw; + sub->cb_data = options->cb_data; // send out the subscription query while we are at it ps_node_subscribe_query(sub); } void ps_node_create_subscriber(struct ps_node_t* node, const char* topic, const struct ps_message_definition_t* type, - struct ps_sub_t* sub, - unsigned int queue_size, - struct ps_allocator_t* allocator, - bool ignore_local) -{ - struct ps_subscriber_options options; - ps_subscriber_options_init(&options); - - options.queue_size = queue_size; - options.allocator = allocator; - options.ignore_local = ignore_local; - - ps_node_create_subscriber_adv(node, topic, type, sub, &options); -} - -void ps_node_create_subscriber_cb(struct ps_node_t* node, const char* topic, const struct ps_message_definition_t* type, struct ps_sub_t* sub, ps_subscriber_fn_cb_t cb, void* cb_data, @@ -620,7 +618,6 @@ void ps_node_create_subscriber_cb(struct ps_node_t* node, const char* topic, con struct ps_subscriber_options options; ps_subscriber_options_init(&options); - options.queue_size = 0; options.cb = cb; options.cb_data = cb_data; options.allocator = allocator; @@ -772,9 +769,10 @@ int ps_node_spin(struct ps_node_t* node) socklen_t fromLength = sizeof(from); int received_bytes = recvfrom(node->socket, (char*)data, size, 0, (struct sockaddr*)&from, &fromLength); - if (received_bytes <= 0) + { break; + } #ifdef PUBSUB_VERBOSE //printf("got transport packet\n"); @@ -817,22 +815,7 @@ int ps_node_spin(struct ps_node_t* node) // queue up the data, and copy :/ (can make zero copy for arduino version) int data_size = received_bytes - sizeof(struct ps_msg_header); - // also todo fastpath for PoD message types - - // okay, if we have the message definition, deserialize and output in a message - void* out_data; - if (sub->type) - { -//theres a leak if you use this and the queue fills up with complex types - out_data = sub->type->decode(data + sizeof(struct ps_msg_header), sub->allocator); - } - else - { - out_data = sub->allocator->alloc(data_size, sub->allocator->context); - memcpy(out_data, data + sizeof(struct ps_msg_header), data_size); - } - - ps_sub_enqueue(sub, out_data, data_size, &message_info); + ps_sub_receive(sub, data + sizeof(struct ps_msg_header), data_size, true, &message_info); #ifdef PUBSUB_VERBOSE //printf("Got message, queue len %i\n", sub->queue_len); @@ -993,9 +976,10 @@ int ps_node_spin(struct ps_node_t* node) socklen_t fromLength = sizeof(from); int received_bytes = recvfrom(node->mc_socket, (char*)data, size, 0, (struct sockaddr*)&from, &fromLength); - if (received_bytes <= 0) + { break; + } //printf("Got discovery msg \n"); @@ -1150,14 +1134,33 @@ int ps_node_spin(struct ps_node_t* node) ep.address = p->addr; ep.port = p->port; + // 0-31 + int recommended_transport = (p->flags >> 1) & 0b11111; + + int preferred_transport = PS_TRANSPORT_UDP; + if (sub->preferred_transport >= 0) + { + preferred_transport = sub->preferred_transport; + } + else + { + preferred_transport = recommended_transport; + } + //printf("preferred transport: %i\n", preferred_transport); + //printf("recommended transport: %i\n", recommended_transport); + + if (preferred_transport != 0) + { + preferred_transport = (1 << (preferred_transport-1)); + } // first match udp if its what we want or all that is offered - if (sub->preferred_transport == PS_TRANSPORT_UDP || p->transports == PS_TRANSPORT_UDP) + if (preferred_transport == PS_TRANSPORT_UDP || p->transports == PS_TRANSPORT_UDP) { ps_udp_subscribe(sub, &ep); } else if (node->num_transports == 0) { - printf("ERROR: Transport mismatch. Do not have desired transport.\n"); + printf("ERROR: Transport mismatch on topic '%s'. Do not have desired transport %i.\n", topic, preferred_transport); } else { @@ -1166,14 +1169,14 @@ int ps_node_spin(struct ps_node_t* node) for (int i = 0; i < node->num_transports; i++) { struct ps_transport_t* transport = &node->transports[i]; - if ((transport->uuid & sub->preferred_transport) != 0) + if ((transport->uuid & preferred_transport) != 0) { int data_index = 0; for (int i = 0; i < 16; i++) { if ((p->transports & (1 << i)) != 0) { - if (sub->preferred_transport == (1 << i)) + if (preferred_transport == (1 << i)) { // this is it break; @@ -1191,6 +1194,7 @@ int ps_node_spin(struct ps_node_t* node) if (!found) { // Otherwise fallback to udp + //printf("Match not found, falling back to udp\n"); ps_udp_subscribe(sub, &ep); } } @@ -1226,13 +1230,9 @@ int ps_node_spin(struct ps_node_t* node) else if (data[0] == PS_DISCOVERY_PROTOCOL_UNSUBSCRIBE) { //printf("Got unsubscribe request\n"); + struct ps_unsubscribe_req_t* msg = (struct ps_unsubscribe_req_t*)data; - int* addr = (int*)&data[1]; - unsigned short* port = (unsigned short*)&data[5]; - - unsigned int* stream_id = (unsigned int*)&data[7]; - - char* topic = (char*)&data[11]; + char* topic = (char*)&data[sizeof(struct ps_unsubscribe_req_t)]; //check if we have a sub matching that topic struct ps_pub_t* pub = 0; @@ -1254,9 +1254,9 @@ int ps_node_spin(struct ps_node_t* node) // remove the client struct ps_client_t client; - client.endpoint.address = *addr; - client.endpoint.port = *port; - client.stream_id = *stream_id; + client.endpoint.address = msg->addr; + client.endpoint.port = msg->port; + client.stream_id = msg->stream_id; ps_pub_remove_client(pub, &client); } else if (data[0] == PS_DISCOVERY_PROTOCOL_QUERY_ALL) @@ -1374,7 +1374,6 @@ void ps_node_set_parameter(struct ps_node_t* node, const char* name, double valu data[0] = PS_UDP_PROTOCOL_PARAM_CHANGE; *(double*)&data[1] = value; - int off = sizeof(struct ps_advertise_req_t); int len = serialize_string(&data[1+8], name) + 9; //also add other info... diff --git a/src/Parameter.c b/src/Parameter.c new file mode 100644 index 0000000..4d5e760 --- /dev/null +++ b/src/Parameter.c @@ -0,0 +1,129 @@ +#include + +#include +#include + +#include +#include +#include +#include + +//#include + +#include + +static double param_internal_callback(const char* name, double value, void* data) +{ + struct ps_parameters* params = (struct ps_parameters*)data; + for (int i = 0; i < params->msg.name_length; i++) + { + if (strcmp(name, params->msg.name[i]) == 0) + { + // enforce min/max + if (value < params->msg.min[i]) + { + value = params->msg.min[i]; + } + else if (value > params->msg.max[i]) + { + value = params->msg.max[i]; + } + free(params->msg.value[i]); + char valstr[50]; + sprintf(valstr, "%lf", value); + params->msg.value[i] = strdup(valstr); + params->callback(name, value, params->cb_data); + ps_pub_publish_ez(¶ms->param_pub, ¶ms->msg); + return value; + } + } + printf("could not find parameter %s\n", name); + + return nan(""); +} + +void ps_create_parameters(struct ps_node_t* node, struct ps_parameters* params_out, ps_param_fancy_cb_t callback, void* data) +{ + node->param_cb = param_internal_callback; + node->param_cb_data = (void*)params_out; + params_out->callback = callback; + params_out->cb_data = data; + params_out->msg.name_length = 0; + ps_node_create_publisher(node, "/parameters", &pubsub__Parameters_def, ¶ms_out->param_pub, true); +} + +void ps_destroy_parameters(struct ps_parameters* params) +{ + params->param_pub.node->param_cb = 0; + if (params->msg.name_length > 0) + { + // free everything! + for (int i = 0; i < params->msg.name_length; i++) + { + free(params->msg.name[i]); + free(params->msg.value[i]); + free(params->msg.description[i]); + } + free(params->msg.name); + free(params->msg.value); + free(params->msg.description); + free(params->msg.type); + free(params->msg.min); + free(params->msg.max); + } + ps_pub_destroy(¶ms->param_pub); +} +/* { FT_String, FF_NONE, "name", 0, 0 }, + { FT_String, FF_NONE, "description", 0, 0 }, + { FT_UInt8, FF_ENUM, "type", 0, 0 }, + { FT_String, FF_NONE, "value", 0, 0 }, + { FT_Float64, FF_NONE, "min", 0, 0 }, + { FT_Float64, FF_NONE, "max", 0, 0 }, */ +void ps_add_parameter_double(struct ps_parameters* params, + const char* name, const char* description, + double value, double min, double max) +{ + int old_len = params->msg.name_length; + int new_len = ++params->msg.name_length; + params->msg.name_length = params->msg.value_length = params->msg.min_length = + params->msg.max_length = params->msg.description_length = params->msg.type_length = new_len; + char** newname = (char**)malloc(sizeof(char*)*new_len); + char** newdesc = (char**)malloc(sizeof(char*)*new_len); + char** newvalue = (char**)malloc(sizeof(char*)*new_len); + double* newmin = (double*)malloc(sizeof(double)*new_len); + double* newmax = (double*)malloc(sizeof(double)*new_len); + uint8_t* newtype = (uint8_t*)malloc(sizeof(uint8_t)*new_len); + for (int i = 0; i < old_len; i++) + { + newname[i] = params->msg.name[i]; + newdesc[i] = params->msg.description[i]; + newvalue[i] = params->msg.value[i]; + newtype[i] = params->msg.type[i]; + newmin[i] = params->msg.min[i]; + newmax[i] = params->msg.max[i]; + } + newname[old_len] = strdup(name); + newdesc[old_len] = strdup(description); + char valstr[50]; + sprintf(valstr, "%lf", value); + newvalue[old_len] = strdup(valstr); + newtype[old_len] = PARAMETERS_DOUBLE; + newmin[old_len] = min; + newmax[old_len] = max; + if (old_len != 0) + { + free(params->msg.name); + free(params->msg.value); + free(params->msg.description); + free(params->msg.type); + free(params->msg.min); + free(params->msg.max); + } + params->msg.name = newname; + params->msg.description = newdesc; + params->msg.value = newvalue; + params->msg.type = newtype; + params->msg.min = newmin; + params->msg.max = newmax; + ps_pub_publish_ez(¶ms->param_pub, ¶ms->msg); +} diff --git a/src/Publisher.c b/src/Publisher.c index 93799e2..09f338d 100644 --- a/src/Publisher.c +++ b/src/Publisher.c @@ -10,10 +10,10 @@ #include -void ps_pub_publish_client(struct ps_pub_t* pub, struct ps_client_t* client, struct ps_msg_t* msg) +static void ps_pub_publish_client(struct ps_pub_t* pub, struct ps_client_t* client, struct ps_msg_ref_t* msg, bool force_publish) { // Skip messages if desired by the client - if (client->modulo > 0) + if (client->modulo > 0 && force_publish == false) { if (pub->sequence_number % client->modulo != 0) { @@ -21,12 +21,12 @@ void ps_pub_publish_client(struct ps_pub_t* pub, struct ps_client_t* client, str } } - if (client->transport) - { - //printf("publishing to custom transport\n"); - client->transport->pub(client->transport, pub, client, msg->data, msg->len); - return; - } + if (client->transport) + { + //printf("publishing to custom transport\n"); + client->transport->pub(client->transport, pub, client, msg); + return; + } // Send it via UDP transport ps_udp_publish(pub, client, msg); @@ -63,15 +63,19 @@ bool ps_pub_add_client(struct ps_pub_t* pub, const struct ps_client_t* client) pub->clients[i] = old_clients[i]; } pub->clients[pub->num_clients - 1] = *client; - + if (old_clients) + { + free(old_clients); + } + // todo this is probably the wrong spot for this // If we are latched, send the new client our last message - if (pub->last_message.data && pub->latched) + if (pub->last_message && pub->latched) { - //printf("publishing latched\n"); - ps_pub_publish_client(pub, &pub->clients[pub->num_clients - 1], &pub->last_message); + //printf("publishing latched\n"); + ps_pub_publish_client(pub, &pub->clients[pub->num_clients - 1], pub->last_message, true); } - return true; + return true; } void ps_pub_add_endpoint_client(struct ps_pub_t* pub, const struct ps_endpoint_t* endpoint, const unsigned int stream_id) @@ -126,13 +130,17 @@ void ps_pub_remove_client(struct ps_pub_t* pub, const struct ps_client_t* client pub->clients[pos++] = old_clients[i]; } } + if (old_clients) + { + free(old_clients); + } } void ps_pub_publish_ez(struct ps_pub_t* pub, void* msg) { if (pub->num_clients > 0 || pub->latched) { - struct ps_msg_t data = pub->message_definition->encode(0, msg); + struct ps_msg_t data = pub->message_definition->encode(msg, pub->allocator); ps_pub_publish(pub, &data); } @@ -141,26 +149,46 @@ void ps_pub_publish_ez(struct ps_pub_t* pub, void* msg) void ps_pub_publish(struct ps_pub_t* pub, struct ps_msg_t* msg) { pub->sequence_number++; + + // exit early if not latched and no clients + if (pub->num_clients == 0 && !pub->latched) + { + free(msg->data);// todo allocator + return; + } + // transfer it to a reference + struct ps_msg_ref_t* ref = (struct ps_msg_ref_t*)malloc(sizeof(struct ps_msg_ref_t)); + ref->len = msg->len; + ref->data = msg->data; + ref->refcount = 1; + + // fill out the header + struct ps_msg_header* hdr = (struct ps_msg_header*)msg->data; + hdr->pid = PS_UDP_PROTOCOL_DATA; + hdr->length = msg->len; + hdr->seq = pub->sequence_number; + hdr->id = 0; + for (unsigned int i = 0; i < pub->num_clients; i++) { struct ps_client_t* client = &pub->clients[i]; - ps_pub_publish_client(pub, client, msg); + ps_pub_publish_client(pub, client, ref, false); } if (pub->latched) { - if (pub->last_message.data) + if (pub->last_message) { //free the old and add the new - free(pub->last_message.data);// todo use allocator + ps_msg_ref_free(pub->last_message, pub->allocator); } - pub->last_message = *msg; + pub->last_message = ref; } else { - free(msg->data);// todo use allocator + ps_msg_ref_free(ref, pub->allocator); } } @@ -176,9 +204,9 @@ void ps_pub_destroy(struct ps_pub_t* pub) //remove it from the node's list of pubs pub->node->num_pubs--; struct ps_pub_t** old_pubs = pub->node->pubs; - if (pub->node->num_pubs) - { - pub->node->pubs = (struct ps_pub_t**)malloc(sizeof(struct ps_pub_t*)*pub->node->num_pubs); + if (pub->node->num_pubs) + { + pub->node->pubs = (struct ps_pub_t**)malloc(sizeof(struct ps_pub_t*)*pub->node->num_pubs); int ind = 0; for (unsigned int i = 0; i < pub->node->num_pubs+1; i++) { @@ -191,17 +219,17 @@ void ps_pub_destroy(struct ps_pub_t* pub) pub->node->pubs[ind++] = old_pubs[i]; } } - } - else - { - pub->node->pubs = 0; - } + } + else + { + pub->node->pubs = 0; + } free(old_pubs); - // free my latched message - if (pub->last_message.data) + // free my latched message + if (pub->last_message) { - free(pub->last_message.data);// todo use allocator + ps_msg_ref_free(pub->last_message, pub->allocator); } pub->clients = 0; diff --git a/src/Serialization.c b/src/Serialization.c index 3b604fe..ba0ce56 100644 --- a/src/Serialization.c +++ b/src/Serialization.c @@ -20,7 +20,7 @@ struct field { uint8_t type; uint32_t length; - uint16_t content_length;//if we are an array + uint16_t content_length;//context dependent char name[50]; //string goes here }; @@ -33,8 +33,6 @@ struct enumeration }; #pragma pack(pop) - - int ps_serialize_message_definition(void* start, const struct ps_message_definition_t* definition) { //ok, write out number of fields @@ -53,7 +51,7 @@ int ps_serialize_message_definition(void* start, const struct ps_message_definit struct field* f = (struct field*)cur; f->type = definition->fields[i].type | (definition->fields[i].flags << 5); f->length = definition->fields[i].length; - f->content_length = definition->fields[i].content_length; + f->content_length = definition->fields[i].struct_index; strcpy(f->name, definition->fields[i].name); cur += 1 + 4 + 2 + strlen(definition->fields[i].name)+ 1; @@ -78,6 +76,9 @@ int ps_serialize_message_definition(void* start, const struct ps_message_definit void ps_copy_message_definition(struct ps_message_definition_t* dst, const struct ps_message_definition_t* src) { + dst->encode = src->encode; + dst->decode = src->decode; + dst->free = src->free; dst->num_fields = src->num_fields; dst->hash = src->hash; dst->fields = (struct ps_msg_field_t*)malloc(sizeof(struct ps_msg_field_t)*dst->num_fields); @@ -89,7 +90,7 @@ void ps_copy_message_definition(struct ps_message_definition_t* dst, const struc dst->fields[i].type = src->fields[i].type; dst->fields[i].flags = src->fields[i].flags; dst->fields[i].length = src->fields[i].length; - dst->fields[i].content_length = src->fields[i].content_length; + dst->fields[i].struct_index = src->fields[i].struct_index; char* name = (char*)malloc(strlen(src->fields[i].name) + 1); strcpy(name, src->fields[i].name); dst->fields[i].name = name; @@ -115,6 +116,7 @@ void ps_deserialize_message_definition(const void * start, struct ps_message_def definition->num_enums = hdr->num_enums; definition->decode = 0; definition->encode = 0; + definition->free = 0; definition->name = 0; definition->fields = (struct ps_msg_field_t*)malloc(sizeof(struct ps_msg_field_t)*definition->num_fields); @@ -133,7 +135,7 @@ void ps_deserialize_message_definition(const void * start, struct ps_message_def definition->fields[i].type = (ps_field_types)f->type & 0x1F; definition->fields[i].flags = f->type >> 5; definition->fields[i].length = f->length; - definition->fields[i].content_length = f->content_length; + definition->fields[i].struct_index = f->content_length; //need to allocate the name int len = strlen(f->name); char* field_name = (char*)malloc(len + 1); @@ -175,8 +177,9 @@ void ps_free_message_definition(struct ps_message_definition_t * definition) free(definition->enums); } -static int GetFieldSize(int type) +static int GetFieldSize(const struct ps_msg_field_t* field) { + int type = field->type; int field_size = 0; if (type == FT_Int8 || type == FT_UInt8) { @@ -194,10 +197,14 @@ static int GetFieldSize(int type) { field_size = 8; } + else if (type == FT_ArrayString) + { + field_size = field->string_length; + } return field_size; } -struct ps_deserialize_iterator ps_deserialize_start(const char* msg, const struct ps_message_definition_t* definition) +struct ps_deserialize_iterator ps_deserialize_start(const void* msg, const struct ps_message_definition_t* definition) { struct ps_deserialize_iterator iter; iter.next_field_index = 0; @@ -208,7 +215,7 @@ struct ps_deserialize_iterator ps_deserialize_start(const char* msg, const struc } // takes in a deserialize iterator and returns pointer to the data and the current field -const char* ps_deserialize_iterate(struct ps_deserialize_iterator* iter, const struct ps_msg_field_t** f, uint32_t* l) +const void* ps_deserialize_iterate(struct ps_deserialize_iterator* iter, const struct ps_msg_field_t** f, uint32_t* l) { if (iter->next_field_index == iter->num_fields) { @@ -248,6 +255,12 @@ const char* ps_deserialize_iterate(struct ps_deserialize_iterator* iter, const s } } } + else if (field->type == FT_StructDefinition) + { + // skip past it + iter->next_field_index += field->length; + return ps_deserialize_iterate(iter, f, l); + } else if (field->type == FT_Struct) { // lets just treat this as a normal element, and leave it to the iterator to handle this @@ -255,16 +268,17 @@ const char* ps_deserialize_iterate(struct ps_deserialize_iterator* iter, const s // okay, lets just not allow struct arrays in arrays for now? // TODO, i dont think this is even implemented in code gen + const struct ps_msg_field_t* struct_field = &iter->fields[field->struct_index]; + iter->struct_num_fields = struct_field->length; + // calculate element *width* uint32_t width = 0; - for (int i = 0; i < field->content_length; i++) + for (int i = 0; i < iter->struct_num_fields; i++) { - const struct ps_msg_field_t* member = &iter->fields[iter->next_field_index++]; - width += GetFieldSize(member->type); + const struct ps_msg_field_t* member = ++struct_field; + width += GetFieldSize(member)*member->length; } - iter->struct_num_fields = field->content_length; - if (field->length > 0) { // fixed array @@ -283,7 +297,7 @@ const char* ps_deserialize_iterate(struct ps_deserialize_iterator* iter, const s } else { - int field_size = GetFieldSize(field->type); + int field_size = GetFieldSize(field); // now handle length if (field->length > 0) @@ -304,11 +318,11 @@ const char* ps_deserialize_iterate(struct ps_deserialize_iterator* iter, const s return position; } -static uint64_t print_field(int type, const char** ptr) +static uint64_t print_field(const struct ps_msg_field_t* field, const char** ptr, const struct ps_message_definition_t* definition) { uint64_t value = 0; // non dynamic types - switch (type) + switch (field->type) { case FT_Int8: printf("%i", (int)*(int8_t*)*ptr); @@ -358,9 +372,40 @@ static uint64_t print_field(int type, const char** ptr) printf("%lf", *(double*)*ptr); *ptr += 8; break; + case FT_ArrayString: + printf("\"%s\"", (char*)*ptr); + *ptr += field->string_length; + break; default: printf("ERROR: unhandled field type when parsing....\n"); } + + if (field->flags == FF_ENUM) + { + const char* name = "Enum Not Found"; + for (unsigned int i = 0; i < definition->num_enums; i++) + { + if (&definition->fields[definition->enums[i].field] == field && value == definition->enums[i].value) + { + name = definition->enums[i].name; + } + } + printf(" (%s)", name); + } + else if (field->flags == FF_BITMASK) + { + const char* name = "Enum Not Found"; + printf(" ("); + for (unsigned int i = 0; i < definition->num_enums; i++) + { + if (&definition->fields[definition->enums[i].field] == field && (definition->enums[i].value & value) != 0) + { + name = definition->enums[i].name; + printf("%s, ", name); + } + } + printf(")"); + } return value; } @@ -368,7 +413,6 @@ void ps_deserialize_print(const void * data, const struct ps_message_definition_ { struct ps_deserialize_iterator iter = ps_deserialize_start(data, definition); const struct ps_msg_field_t* field; uint32_t length; const char* ptr; - int it = 0; while (ptr = ps_deserialize_iterate(&iter, &field, &length)) { if (field_name && strcmp(field_name, field->name) != 0) @@ -380,7 +424,7 @@ void ps_deserialize_print(const void * data, const struct ps_message_definition_ // strings are already null terminated if (field->length == 1) { - printf("%s: %s\n", field->name, ptr); + printf("%s: \"%s\"\n", field->name, ptr); } else { @@ -423,9 +467,30 @@ void ps_deserialize_print(const void * data, const struct ps_message_definition_ printf(" { "); for (int f = 0; f < iter.struct_num_fields; f++) { - printf("%s: ", (field+1+f)->name); - print_field((field+1+f)->type, &ptr); - printf(" "); + const struct ps_msg_field_t* sf = definition->fields + 1 + field->struct_index + f; + printf("%s: ", sf->name); + if (sf->length > 1) + { + printf("["); + for (int i = 0; i < sf->length; i++) + { + print_field(sf, &ptr, definition); + + if (i != sf->length - 1) + { + printf(", "); + } + } + printf("]"); + } + else + { + print_field(sf, &ptr, definition); + } + if (f != (iter.struct_num_fields - 1)) + { + printf(", "); + } } printf(" }"); } @@ -463,38 +528,7 @@ void ps_deserialize_print(const void * data, const struct ps_message_definition_ for (unsigned int i = 0; i < length; i++) { - uint64_t value = 0; - value = print_field(field->type, &ptr); - - if (field->flags > 0) - { - if (field->flags == FF_ENUM) - { - const char* name = "Enum Not Found"; - for (unsigned int i = 0; i < definition->num_enums; i++) - { - if (definition->enums[i].field == it && value == definition->enums[i].value) - { - name = definition->enums[i].name; - } - } - printf(" (%s)", name); - } - else if (field->flags == FF_BITMASK) - { - const char* name = "Enum Not Found"; - printf(" ("); - for (unsigned int i = 0; i < definition->num_enums; i++) - { - if (definition->enums[i].field == it && (definition->enums[i].value & value) != 0) - { - name = definition->enums[i].name; - printf("%s, ", name); - } - } - printf(")"); - } - } + print_field(field, &ptr, definition); if (field->length == 1) { @@ -510,7 +544,6 @@ void ps_deserialize_print(const void * data, const struct ps_message_definition_ } } } - it++; } } @@ -523,8 +556,6 @@ const char* TypeToString(ps_field_types type) { switch (type) { - // declarations - // . . . case FT_Int8: return "int8"; case FT_Int16: @@ -546,11 +577,13 @@ const char* TypeToString(ps_field_types type) case FT_Float64: return "double"; case FT_String: - // statements executed if the expression equals the - // value of this constant_expression return "string"; case FT_Struct: return "struct"; + case FT_ArrayString: + return "astring"; + case FT_StructDefinition: + return "struct"; default: return "Unknown Type"; } @@ -565,42 +598,49 @@ void ps_print_definition(const struct ps_message_definition_t* definition, bool int end_tab = 0; for (unsigned int i = 0; i < definition->num_fields; i++) { + const struct ps_msg_field_t* field = &definition->fields[i]; + + const char* tab = i < end_tab ? " " : ""; + //print out any relevant enums for (unsigned int j = 0; j < definition->num_enums; j++) { if (definition->enums[j].field == i) { - printf("%s %i\n", definition->enums[j].name, definition->enums[j].value); + printf("%s%s %i\n", tab, definition->enums[j].name, definition->enums[j].value); } } - const char* type_name = ""; - ps_field_types type = definition->fields[i].type; - if (definition->fields[i].flags == FF_ENUM) + ps_field_types type = field->type; + const char* flag = ""; + if (field->flags == FF_ENUM) { - printf("enum "); + flag = "enum "; } - else if (definition->fields[i].flags == FF_BITMASK) + else if (field->flags == FF_BITMASK) { - printf("bitmask "); + flag = "bitmask "; } - const char* tab = i < end_tab ? " " : ""; - if (definition->fields[i].length > 1) + if (field->type == FT_StructDefinition) + { + printf("%s%s%s %s\n", tab, flag, TypeToString(field->type), field->name); + } + else if (field->length > 1) { - printf("%s%s %s[%i]\n", tab, TypeToString(definition->fields[i].type), definition->fields[i].name, definition->fields[i].length); + printf("%s%s%s %s[%i]\n", tab, flag, TypeToString(field->type), field->name, field->length); } - else if (definition->fields[i].length == 0) + else if (field->length == 0) { - printf("%s%s[] %s\n", tab, TypeToString(definition->fields[i].type), definition->fields[i].name);// dynamic array + printf("%s%s%s %s[]\n", tab, flag, TypeToString(field->type), field->name);// dynamic array } else { - printf("%s%s %s\n", tab, TypeToString(definition->fields[i].type), definition->fields[i].name); + printf("%s%s%s %s\n", tab, flag, TypeToString(field->type), field->name); } - if (definition->fields[i].type == FT_Struct) + if (field->type == FT_StructDefinition) { - end_tab = i + 1 + definition->fields[i].content_length; + end_tab = i + 1 + field->length; } } } @@ -608,21 +648,36 @@ void ps_print_definition(const struct ps_message_definition_t* definition, bool void ps_msg_alloc(unsigned int size, struct ps_allocator_t* allocator, struct ps_msg_t* out_msg) { out_msg->len = size; - if (allocator) - { - out_msg->data = (void*)allocator->alloc(size + sizeof(struct ps_msg_header), allocator->context); - } - else - { - out_msg->data = (void*)((char*)malloc(size + sizeof(struct ps_msg_header))); - } + if (allocator) + { + out_msg->data = (void*)allocator->alloc(size + sizeof(struct ps_msg_header), allocator->context); + } + else + { + out_msg->data = (void*)((char*)malloc(size + sizeof(struct ps_msg_header))); + } +} + +void ps_msg_ref_add(struct ps_msg_ref_t* msg) +{ + msg->refcount++; +} + +void ps_msg_ref_free(struct ps_msg_ref_t* msg, struct ps_allocator_t* allocator) +{ + msg->refcount--; + if (msg->refcount == 0) + { + allocator->free(msg->data, allocator->context); + allocator->free(msg, allocator->context); + } } -struct ps_msg_t ps_msg_cpy(const struct ps_msg_t* msg) +struct ps_msg_t ps_msg_cpy(const struct ps_msg_t* msg, struct ps_allocator_t* allocator) { struct ps_msg_t out; - ps_msg_alloc(msg->len, 0, &out); + ps_msg_alloc(msg->len, allocator, &out); memcpy(ps_get_msg_start(out.data), ps_get_msg_start(msg->data), msg->len); return out; } diff --git a/src/Subscriber.c b/src/Subscriber.c index 19119f1..241baeb 100644 --- a/src/Subscriber.c +++ b/src/Subscriber.c @@ -7,38 +7,42 @@ #ifndef ANDROID #include #endif +#include #include -void ps_sub_enqueue(struct ps_sub_t* sub, void* data, int data_size, const struct ps_msg_info_t* message_info) +void ps_sub_receive(struct ps_sub_t* sub, void* encoded_message, int data_size, bool is_reference, const struct ps_msg_info_t* message_info) { - // Implement a LIFO queue. Is this the best option? - int new_start = sub->queue_start - 1; - if (new_start < 0) - { - new_start += sub->queue_size; - } + // if is_reference is true, we must make a copy for the subscriber to own + // okay, so how do we let the callback specify if it wants a copy or not? - // If no queue size, just run the callback immediately - if (sub->queue_size == 0) + //todo make this not always require owning the data + //how do I release a "loaned" message? also, how do I loan? + if (sub->cb_raw) { - sub->cb(data, data_size, sub->cb_data, message_info); + // todo can avoid the copy if the allocator is used + void* out_data; + if (is_reference) + { + out_data = sub->allocator->alloc(data_size, sub->allocator->context); + memcpy(out_data, encoded_message, data_size); + } + else + { + out_data = encoded_message; + } + sub->cb_raw(out_data, data_size, sub->cb_data, message_info); } - // Handle replacement if the queue is full - else if (sub->queue_size == sub->queue_len) + + if (sub->cb) { - // we'll replace the item at the back by shifting the queue around - free(sub->queue[sub->queue_start]); - // add at the front - sub->queue[new_start] = data; - sub->queue_start = new_start; + void* out_data = sub->type->decode(encoded_message, sub->allocator); + sub->cb(out_data, data_size, sub->cb_data, message_info); } - else + + if (!is_reference && !sub->cb_raw) { - // add to the front - sub->queue_len++; - sub->queue[new_start] = data; - sub->queue_start = new_start; + sub->allocator->free(encoded_message, sub->allocator->context); } } @@ -51,63 +55,31 @@ void ps_sub_destroy(struct ps_sub_t* sub) for (unsigned int i = 0; i < sub->node->num_transports; i++) { sub->node->transports[i].unsubscribe(&sub->node->transports[i], sub); - } + } //remove it from my list of subs sub->node->num_subs--; - if (sub->node->num_subs == 0) - { - free(sub->node->subs); - sub->node->subs = 0; - } - else - { - struct ps_sub_t** old_subs = sub->node->subs; - sub->node->subs = (struct ps_sub_t**)malloc(sizeof(struct ps_sub_t*)*sub->node->num_subs); - int ind = 0; - for (unsigned int i = 0; i < sub->node->num_subs+1; i++) - { - if (old_subs[i] == sub) - { - //skip me - } - else - { - sub->node->subs[ind++] = old_subs[i]; - } - } - free(old_subs); - } - - // free any queued up received messages and the queue itself - for (int i = 0; i < sub->queue_len; i++) + if (sub->node->num_subs == 0) { - int index = (sub->queue_start + i)%sub->queue_size; - if (sub->queue[index] != 0) - free(sub->queue[index]); + free(sub->node->subs); + sub->node->subs = 0; } - free(sub->queue); -} - -void* ps_sub_deque(struct ps_sub_t* sub) -{ - if (sub->queue_len == 0) + else { - //printf("Warning: dequeued when there was nothing in queue\n"); - return 0; - } - - // we are dequeueing, so remove the newest first (from the front) - sub->queue_len--; - - void* data = sub->queue[sub->queue_start]; - sub->queue[sub->queue_start] = 0; - - int new_start = sub->queue_start+1; - if (new_start >= sub->queue_size) - { - new_start -= sub->queue_size; - } - sub->queue_start = new_start; - return data; + struct ps_sub_t** old_subs = sub->node->subs; + sub->node->subs = (struct ps_sub_t**)malloc(sizeof(struct ps_sub_t*)*sub->node->num_subs); + int ind = 0; + for (unsigned int i = 0; i < sub->node->num_subs+1; i++) + { + if (old_subs[i] == sub) + { + //skip me + } + else + { + sub->node->subs[ind++] = old_subs[i]; + } + } + free(old_subs); + } } diff --git a/src/TCPTransport.c b/src/TCPTransport.c new file mode 100644 index 0000000..910125c --- /dev/null +++ b/src/TCPTransport.c @@ -0,0 +1,885 @@ +// Must be included first on windows +#include + +#include +#include +#include +#include +#include +//#include + +#include +#include +#include + +#ifdef __unix__ +#include +#endif + +void remove_client_socket(struct ps_tcp_transport_impl* transport, int socket, struct ps_node_t* node) +{ + // find the index + int i = 0; + for (; i < transport->num_clients; i++) + { + if (transport->clients[i].socket == socket)// socket packed in address + { + break; + } + } + +#ifdef _WIN32 + closesocket(socket); +#else + close(socket); +#endif + + if (transport->clients[i].packet_data) + { + free(transport->clients[i].packet_data); + } + + if (transport->clients[i].queued_message) + { + ps_msg_ref_free(transport->clients[i].queued_message, transport->clients[i].publisher->allocator); + } + + // free queued messages + if (transport->clients[i].num_queued_messages) + { + for (int j = 0; j < transport->clients[i].num_queued_messages; j++) + { + ps_msg_ref_free(transport->clients[i].queued_messages[j].msg, transport->clients[i].publisher->allocator); + } + free(transport->clients[i].queued_messages); + } + + struct ps_tcp_client_t* old_clients = transport->clients; + transport->num_clients -= 1; + + // close the socket and dont wait on it anymore + ps_event_set_remove_socket(&node->events, transport->clients[i].socket); + + if (transport->num_clients) + { + transport->clients = (struct ps_tcp_client_t*)malloc(sizeof(struct ps_tcp_client_t) * transport->num_clients); + for (int j = 0; j < i; j++) + { + transport->clients[j] = old_clients[j]; + } + + for (int j = i + 1; j <= transport->num_clients; j++) + { + transport->clients[j - 1] = old_clients[j]; + } + } + free(old_clients); +} + +void ps_tcp_remove_connection(struct ps_tcp_transport_impl* impl, int index) +{ + // Free our subscribers and any buffers + int iter = 0; + int new_size = impl->num_connections - 1; + struct ps_tcp_transport_connection* new_connections = new_size == 0 ? 0 : (struct ps_tcp_transport_connection*)malloc(sizeof(struct ps_tcp_transport_connection) * new_size); + for (int i = 0; i < impl->num_connections; i++) + { + if (i != index) + { + new_connections[iter++] = impl->connections[i]; + continue; + } + + if (!impl->connections[i].waiting_for_header) + { + free(impl->connections[i].packet_data); + } + ps_event_set_remove_socket(&impl->node->events, impl->connections[i].socket); +#ifdef _WIN32 + closesocket(impl->connections[i].socket); +#else + close(impl->connections[i].socket); +#endif + } + impl->num_connections = new_size; + free(impl->connections); + impl->connections = new_connections; +} + +int ps_tcp_transport_spin(struct ps_transport_t* transport, struct ps_node_t* node) +{ + struct ps_tcp_transport_impl* impl = (struct ps_tcp_transport_impl*)transport->impl; + int socket = accept(impl->socket, 0, 0); + if (socket > 0) + { +#ifdef PUBSUB_VERBOSE + printf("Got new socket connection!\n"); +#endif + + // add it to the list yo + impl->num_clients++; + struct ps_tcp_client_t* old_sockets = impl->clients; + impl->clients = (struct ps_tcp_client_t*)malloc(sizeof(struct ps_tcp_client_t) * impl->num_clients); + for (int i = 0; i < impl->num_clients - 1; i++) + { + impl->clients[i] = old_sockets[i]; + } + struct ps_tcp_client_t* new_client = &impl->clients[impl->num_clients - 1]; + new_client->socket = socket; + new_client->needs_removal = false; + new_client->current_packet_size = 0; + new_client->desired_packet_size = 0; + new_client->packet_data = 0; + new_client->queued_message = 0; + new_client->queued_message_length = 0; + new_client->queued_message_written = 0; + new_client->queued_messages = 0; + new_client->num_queued_messages = 0; + + // set non-blocking +#ifdef _WIN32 + DWORD nonBlocking = 1; + if (ioctlsocket(socket, FIONBIO, &nonBlocking) != 0) + { + printf("Failed to Set Socket as Non-Blocking!\n"); + closesocket(socket); + return 0; + } +#endif +#ifdef ARDUINO + fcntl(socket, F_SETFL, O_NONBLOCK); +#endif +#ifdef __unix__ + int flags = fcntl(socket, F_GETFL); + fcntl(socket, F_SETFL, flags | O_NONBLOCK); +#endif + + ps_event_set_add_socket(&node->events, socket); + + if (impl->num_clients - 1) + { + free(old_sockets); + } + } + //printf("polled\n"); + +// remove any old sockets + for (int i = 0; i < impl->num_clients; i++) + { + if (impl->clients[i].needs_removal) + { + // add it to the list + struct ps_client_t client; + client.endpoint.address = impl->clients[i].socket; + client.endpoint.port = 255;// p->port; + client.stream_id = 0; + ps_pub_remove_client(impl->clients[i].publisher, &client);// todo this is probably unsafe.... + + remove_client_socket(impl, impl->clients[i].socket, impl->clients[i].publisher->node); + + i = i - 1; + break; + } + } + + // update our sockets yo + for (int i = 0; i < impl->num_clients; i++) + { + struct ps_tcp_client_t* client = &impl->clients[i]; + + // send queued messages until we block or cant send anymore + while (client->queued_message != 0) + { + int to_send = client->queued_message_length - client->queued_message_written; + char* data = (char*)client->queued_message->data; + int sent = send(client->socket, &data[client->queued_message_written], to_send, 0); + if (sent > 0) + { + client->queued_message_written += sent; + } +#ifdef WIN32 + else if (sent < 0 && WSAGetLastError() != WSAEWOULDBLOCK) +#else + else if (sent < 0 && errno != EAGAIN) +#endif + { + //printf("needs removal %i\n", errno); + client->needs_removal = true; + } + + //printf("Sending more: %i to make %i of %i\n", sent, client->queued_message_written, client->queued_message_length); + + if (client->queued_message_written == client->queued_message_length) + { + //printf("Message sent.\n"); + ps_msg_ref_free(client->queued_message, client->publisher->allocator); + client->queued_message = 0; + + // we finished! check if there are more to send + if (client->num_queued_messages > 0) + { + // grab a message from the front of our message queue + client->queued_message = client->queued_messages[0].msg; + client->queued_message_written = 0; + client->queued_message_length = client->queued_messages[0].msg->len + sizeof(struct ps_msg_header); + + client->num_queued_messages -= 1; + if (client->num_queued_messages == 0) + { + free(client->queued_messages); + client->queued_messages = 0; + continue; + } + + struct ps_tcp_client_queued_message_t* msgs = (struct ps_tcp_client_queued_message_t*)malloc(client->num_queued_messages * sizeof(struct ps_tcp_client_queued_message_t)); + for (int i = 0; i < client->num_queued_messages; i++) + { + msgs[i] = client->queued_messages[i + 1];// take from the front + } + free(client->queued_messages); + client->queued_messages = msgs; + + // continue so we can attempt to send again + } + else + { + ps_event_set_remove_socket_write(&node->events, client->socket); + break;// no more to send + } + } + else + { + break;// we couldnt send anymore atm + } + } + + // check for new data and add it to the packet if present + char buf[1500]; + // if we havent gotten a header yet, just check for that + if (client->desired_packet_size == 0) + { + const int header_size = sizeof(struct ps_msg_header); + int len = recv(client->socket, buf, header_size, MSG_PEEK); + //printf("peek %i desired size %i\n", len, header_size); + if (len == 0) + { + client->needs_removal = true; + continue; + } + if (len < header_size) + { + continue;// no header yet + } + + char message_type = buf[0];// not used atm + + // we actually got the header! start looking for the message + len = recv(client->socket, buf, header_size, 0); + //connection->packet_type = message_type; + //printf("recv %i from client->socket desired size 0 2\n", len); + //client->waiting_for_header = false; + client->desired_packet_size = *(uint32_t*)&buf[1]; + //printf("Incoming message with %i bytes\n", client->desired_packet_size); + client->packet_data = (char*)malloc(client->desired_packet_size); + + client->current_packet_size = 0; + } + // read in the message + if (client->desired_packet_size != 0) + { + int remaining_size = client->desired_packet_size - client->current_packet_size; + // check for new messages and read until we hit packet size + int len = recv(client->socket, &client->packet_data[client->current_packet_size], remaining_size, 0); + //printf("recv %i from client->socket\n", len); + if (len > 0) + { + //printf("Read %i bytes of message\n", len); + client->current_packet_size += len; + + if (client->current_packet_size == client->desired_packet_size) + { +#ifdef PUBSUB_VERBOSE + printf("message finished\n"); +#endif + + if (true)// todo look at message id + { + // its a subscribe + const char* topic = &client->packet_data[4]; + // check if this matches any of our publishers + for (unsigned int pi = 0; pi < node->num_pubs; pi++) + { + struct ps_pub_t* pub = node->pubs[pi]; + if (strcmp(topic, pub->topic) == 0) + { + uint32_t skip = *(uint32_t*)&client->packet_data[0]; + // send response and start publishing + struct ps_client_t sub_client; + sub_client.endpoint.address = client->socket; + sub_client.endpoint.port = 255;// p->port; + sub_client.last_keepalive = 10000000000000;//GetTickCount64();// use the current time stamp + sub_client.sequence_number = 0; + sub_client.stream_id = 0; + sub_client.modulo = skip > 0 ? skip + 1 : 0; + sub_client.transport = transport; + + impl->clients[i].publisher = pub; + + // send the client the acknowledgement and message definition + char buf[1500]; + int32_t length = ps_serialize_message_definition((void*)buf, pub->message_definition); + struct ps_msg_header hdr; + hdr.pid = PS_TCP_PROTOCOL_MESSAGE_DEFINITION;// message definition + hdr.length = length; + hdr.id = hdr.seq = 0; + send(impl->clients[i].socket, (char*)&hdr, sizeof(hdr), 0); + send(impl->clients[i].socket, buf, length, 0); + +#ifdef PUBSUB_VERBOSE + printf("TCPTransport: Got subscribe request, adding client if we haven't already\n"); +#endif + ps_pub_add_client(pub, &sub_client); + + break; + } + } + } + + free(client->packet_data); + client->packet_data = 0; + client->desired_packet_size = 0; + } + } + } + } + + int message_count = 0; + for (int i = 0; i < impl->num_connections; i++) + { + struct ps_tcp_transport_connection* connection = &impl->connections[i]; + char buf[1500]; + if (connection->connecting) + { + //printf("checking for connected\n"); + // select to check for writability + fd_set wfds; + struct timeval tv; + int retval; + + FD_ZERO(&wfds); + FD_SET(connection->socket, &wfds); + + tv.tv_sec = 0; + tv.tv_usec = 0; + retval = select(connection->socket + 1, NULL, &wfds, NULL, &tv); + //printf("select\n"); + if (retval == -1) + { + // error? + printf("socket errored while connecting\n"); + } + else if (retval) + { + // socket is writable + //printf("socket writable\n"); + + // make the subscribe request in a "packet" + // a packet is an int length followed by data + int32_t length = strlen(connection->subscriber->topic) + 1 + 4; + + struct ps_msg_header hdr; + hdr.pid = 0x01; + hdr.length = length; + hdr.id = hdr.seq = 0; + send(connection->socket, (char*)&hdr, sizeof(hdr), 0); + + // make the request + uint32_t skip = connection->subscriber->skip; + send(connection->socket, (char*)&skip, 4, 0); + send(connection->socket, connection->subscriber->topic, length - 4, 0); + + connection->connecting = false; + } + } + // if we havent gotten a header yet, just check for that + else if (connection->waiting_for_header) + { + const int header_size = sizeof(struct ps_msg_header); + int len = recv(connection->socket, buf, header_size, MSG_PEEK); + //printf("peek got: %i\n", len); + if (len == 0) + { + // we got disconnected + ps_tcp_remove_connection(impl, i); + i--; + continue; + } + else if (len < header_size) + { + continue;// no header yet + } + + char message_type = buf[0]; + + // we actually got the header! start looking for the message + len = recv(connection->socket, buf, header_size, 0); + connection->packet_type = message_type; + connection->waiting_for_header = false; + connection->packet_size = *(uint32_t*)&buf[1]; + //printf("Incoming message with %i bytes\n", impl->connections[i].packet_size); + connection->packet_data = (char*)connection->subscriber->allocator->alloc(connection->packet_size, connection->subscriber->allocator->context); + + connection->current_size = 0; + } + else // read in the message + { + int remaining_size = connection->packet_size - connection->current_size; + + // check for new messages and read until we hit packet size + int len = recv(connection->socket, &connection->packet_data[connection->current_size], remaining_size, 0); + //printf("len %i\n", len); + if (len == 0) + { + // we got disconnected + ps_tcp_remove_connection(impl, i); + i--; + continue; + } + else if (len > 0) + { + connection->current_size += len; + //printf("Read %i bytes of message, so far: %i\n", len, connection->current_size); + + if (connection->current_size == connection->packet_size) + { + //printf("message finished type %x\n", connection->packet_type); + if (connection->packet_type == PS_TCP_PROTOCOL_MESSAGE_DEFINITION) + { + //printf("Was message definition\n"); + if (connection->subscriber->type == 0) + { + // todo put this in a function so we cant accidentally forget it + if (connection->subscriber->received_message_def.fields == 0) + { + ps_deserialize_message_definition(connection->packet_data, &connection->subscriber->received_message_def); + } + + // call the callback as well + if (node->def_cb) + { + node->def_cb(&connection->subscriber->received_message_def, node->def_cb_data); + } + } + + connection->subscriber->allocator->free(connection->packet_data, connection->subscriber->allocator->context); + } + else if (connection->packet_type == PS_TCP_PROTOCOL_DATA) + { + //printf("added to queue\n"); + // decode and add it to the queue + struct ps_msg_info_t message_info; + message_info.address = connection->endpoint.address; + message_info.port = connection->endpoint.port; + + ps_sub_receive(connection->subscriber, connection->packet_data, connection->packet_size, false, &message_info); + + // remove the reference to packet data so we don't try and double free it on destroy + connection->packet_data = 0; + message_count++; + } + else + { + // unhandled packet id + connection->subscriber->allocator->free(connection->packet_data, connection->subscriber->allocator->context); + } + connection->waiting_for_header = true; + } + } + } + } + return message_count; +} + +void ps_tcp_transport_pub(struct ps_transport_t* transport, struct ps_pub_t* publisher, struct ps_client_t* client, struct ps_msg_ref_t* msg) +{ + // todo dont + int length = msg->len; + void* message = msg->data; + struct ps_tcp_transport_impl* impl = (struct ps_tcp_transport_impl*)transport->impl; + + // the client packs the socket id in the addr + int socket = client->endpoint.address; + + // okay, so new version, if any write fails (EAGAIN or < expected size) + // we make a copy of the entire message and store it on that client to try and send again in our update loop + // if we get into this function and this client already has a queued message, just drop this one + struct ps_tcp_client_t* tclient = 0; + for (int i = 0; i < impl->num_clients; i++) + { + if (impl->clients[i].socket == socket) + { + tclient = &impl->clients[i]; + break; + } + } + + if (tclient->queued_message != 0) + { + // check if we have queue space left + + // for now hardcode max queue size + const int max_queue_size = 10; + + // add a reference to the message and queue it up + ps_msg_ref_add(msg); + + // this if statement is unnecessary, but I added it for the sake of testing/completeness + if (tclient->queued_message == 0) + { + tclient->queued_message = msg; + tclient->queued_message_length = length + sizeof(struct ps_msg_header); + tclient->queued_message_written = 0; + } + else if (tclient->num_queued_messages >= max_queue_size) + { + // todo use a deque lol + // swap everything down, freeing the first + for (int i = tclient->num_queued_messages - 1; i >= 1; i--) + { + tclient->queued_messages[i] = tclient->queued_messages[i - 1]; + } + tclient->queued_messages[0].msg = msg; + printf("dropped message on topic '%s'\n", publisher->topic); + return;// drop it, we are out of queue space + } + else + { + //printf("queuing up message %i on topic '%s'\n", tclient->num_queued_messages, publisher->topic); + + // add the message to the front of the queue + tclient->num_queued_messages += 1; + struct ps_tcp_client_queued_message_t* msgs = (struct ps_tcp_client_queued_message_t*)malloc(tclient->num_queued_messages * sizeof(struct ps_tcp_client_queued_message_t)); + + msgs[0].msg = msg; + for (int i = 0; i < tclient->num_queued_messages - 1; i++) + { + msgs[i + 1] = tclient->queued_messages[i]; + } + free(tclient->queued_messages); + tclient->queued_messages = msgs; + + return; + } + } + //printf("started writing\n"); + // try and write, if any of these fail, make a copy + + // the message header is already filled out with the packet id and length + + int32_t desired_len = sizeof(struct ps_msg_header) + length; + //printf("trying to send message of %i bytes\n", desired_len); + int32_t c = send(socket, (char*)message, desired_len, 0); + if (c < desired_len && c >= 0) + { + tclient->queued_message_written = c; + goto FAILCOPY; + } + if (c < 0) + { +#ifdef WIN32 + int error = WSAGetLastError(); + if (error == WSAEWOULDBLOCK) +#else + if (errno == EAGAIN || errno == EWOULDBLOCK) +#endif + { + tclient->queued_message_written = 0; + goto FAILCOPY; + } + goto FAILDISCONNECT; + } + + //printf("wrote all\n"); + return; + + char* data; +FAILDISCONNECT: + //printf("Disconnected: %s\n", strerror(err)); + tclient->needs_removal = true; + return; + +FAILCOPY: + // add a reference count and put it in our queue + ps_msg_ref_add(msg); + + //printf("Wrote %i bytes\n", tclient->queued_message_written); + + tclient->queued_message = msg; + tclient->queued_message_length = length + sizeof(struct ps_msg_header); + ps_event_set_add_socket_write(&publisher->node->events, socket); + return; +} + +void ps_tcp_transport_subscribe(struct ps_transport_t* transport, struct ps_sub_t* subscriber, struct ps_endpoint_t* ep, uint32_t transport_info) +{ + struct ps_tcp_transport_impl* impl = (struct ps_tcp_transport_impl*)transport->impl; + + // check if we already have a sub for this subscriber with this endpoint + // if so, ignore it + for (int i = 0; i < impl->num_connections; i++) + { + if (impl->connections[i].endpoint.port == ep->port && + impl->connections[i].endpoint.address == ep->address && + impl->connections[i].subscriber->sub_id == subscriber->sub_id) + { + return; + } + } + +#ifdef _WIN32 + SOCKET sock = socket(AF_INET, SOCK_STREAM, 0); +#else + int sock = socket(AF_INET, SOCK_STREAM, 0); +#endif + + // set non-blocking +#ifdef _WIN32 + DWORD nonBlocking = 1; + if (ioctlsocket(sock, FIONBIO, &nonBlocking) != 0) + { + ps_print_socket_error("Failed to Set Socket as Non-Blocking"); + closesocket(sock); + return; + } +#endif +#ifdef ARDUINO + fcntl(sock, F_SETFL, O_NONBLOCK); +#endif +#ifdef __unix__ + int flags = fcntl(sock, F_GETFL); + fcntl(sock, F_SETFL, flags | O_NONBLOCK); +#endif + + // Actually connect + //printf("connecting\n"); + struct sockaddr_in server_addr; + server_addr.sin_family = AF_INET; + server_addr.sin_addr.s_addr = htonl(ep->address); + server_addr.sin_port = htons(transport_info); + int connect_result = connect(sock, (struct sockaddr*)&server_addr, sizeof(server_addr)); + if (connect_result != 0) + { +#ifdef _WIN32 + if (WSAGetLastError() != WSAEWOULDBLOCK)//WSAEINPROGRESS +#else + if (errno != EINPROGRESS) +#endif + { + ps_print_socket_error("error connecting tcp socket"); + return; + } + } + + //printf("%i %i %i %i\n", (ep->address & 0xFF000000) >> 24, (ep->address & 0xFF0000) >> 16, (ep->address & 0xFF00) >> 8, (ep->address & 0xFF)); + + // make the subscribe request in a "packet" + // a packet is an int length followed by data + /*int8_t packet_type = 0x01;//subscribe + send(sock, (char*)&packet_type, 1, 0); + + int32_t length = strlen(subscriber->topic) + 1 + 4; + send(sock, (char*)&length, 4, 0); + + // make the request + char buffer[500]; + strcpy(buffer, subscriber->topic); + uint32_t skip = subscriber->skip; + send(sock, (char*)&skip, 4, 0); + send(sock, buffer, length - 4, 0);*/ + + // add the socket to the list of connections + impl->num_connections++; + struct ps_tcp_transport_connection* old_connections = impl->connections; + impl->connections = (struct ps_tcp_transport_connection*)malloc(sizeof(struct ps_tcp_transport_connection) * impl->num_connections); + for (int i = 0; i < impl->num_connections - 1; i++) + { + impl->connections[i] = old_connections[i]; + } + + struct ps_tcp_transport_connection* new_connection = &impl->connections[impl->num_connections - 1]; + new_connection->socket = sock; + new_connection->endpoint = *ep; + new_connection->waiting_for_header = true; + new_connection->subscriber = subscriber; + new_connection->connecting = true; + + ps_event_set_add_socket(&subscriber->node->events, sock); + + if (impl->num_connections - 1) + { + free(old_connections); + } +} + +void ps_tcp_transport_unsubscribe(struct ps_transport_t* transport, struct ps_sub_t* subscriber) +{ + struct ps_tcp_transport_impl* impl = (struct ps_tcp_transport_impl*)transport->impl; + + // remove all transports which reference this subscriber + int num_to_remove = 0; + for (int i = 0; i < impl->num_connections; i++) + { + if (impl->connections[i].subscriber == subscriber) + { + num_to_remove++; + } + } + +#ifdef PUBSUB_VERBOSE + printf("Removing %i tcp subs\n", num_to_remove); +#endif + + if (num_to_remove > 0) + { + // Free our subscribers and any buffers + int iter = 0; + int new_size = impl->num_connections - num_to_remove; + struct ps_tcp_transport_connection* new_connections = new_size == 0 ? 0 : (struct ps_tcp_transport_connection*)malloc(sizeof(struct ps_tcp_transport_connection) * new_size); + for (int i = 0; i < impl->num_connections; i++) + { + if (impl->connections[i].subscriber != subscriber) + { + new_connections[iter++] = impl->connections[i]; + continue; + } + + if (!impl->connections[i].waiting_for_header) + { + struct ps_sub_t* sub = impl->connections[i].subscriber; + sub->allocator->free(impl->connections[i].packet_data, sub->allocator->context); + } + ps_event_set_remove_socket(&impl->node->events, impl->connections[i].socket); +#ifdef _WIN32 + closesocket(impl->connections[i].socket); +#else + close(impl->connections[i].socket); +#endif + } + impl->num_connections = new_size; + free(impl->connections); + impl->connections = new_connections; + } +} + + +void ps_tcp_transport_destroy(struct ps_transport_t* transport) +{ + struct ps_tcp_transport_impl* impl = (struct ps_tcp_transport_impl*)transport->impl; + + // Free our subscribers and any buffers + for (int i = 0; i < impl->num_connections; i++) + { + if (!impl->connections[i].waiting_for_header) + { + struct ps_sub_t* sub = impl->connections[i].subscriber; + sub->allocator->free(impl->connections[i].packet_data, sub->allocator->context); + } + ps_event_set_remove_socket(&impl->node->events, impl->connections[i].socket); +#ifdef _WIN32 + closesocket(impl->connections[i].socket); +#else + close(impl->connections[i].socket); +#endif + } + + for (int i = 0; i < impl->num_clients; i++) + { + ps_event_set_remove_socket(&impl->node->events, impl->clients[i].socket); +#ifdef _WIN32 + closesocket(impl->clients[i].socket); +#else + close(impl->clients[i].socket); +#endif + } + +#ifdef _WIN32 + closesocket(impl->socket); +#else + close(impl->socket); +#endif + + if (impl->num_clients) + { + free(impl->clients); + } + + if (impl->num_connections) + { + free(impl->connections); + } + + free(impl); +} + +void ps_tcp_transport_init(struct ps_transport_t* transport, struct ps_node_t* node) +{ +#ifdef __unix__ + signal(SIGPIPE, SIG_IGN); +#endif + + transport->spin = ps_tcp_transport_spin; + transport->subscribe = ps_tcp_transport_subscribe; + transport->unsubscribe = ps_tcp_transport_unsubscribe; + transport->destroy = ps_tcp_transport_destroy; + transport->pub = ps_tcp_transport_pub; + transport->uuid = 1; + + struct ps_tcp_transport_impl* impl = (struct ps_tcp_transport_impl*)malloc(sizeof(struct ps_tcp_transport_impl)); + + impl->num_clients = 0; + impl->num_connections = 0; + + impl->node = node; + + impl->socket = socket(AF_INET, SOCK_STREAM, 0); + + struct sockaddr_in server_addr; + server_addr.sin_family = AF_INET; + server_addr.sin_addr.s_addr = INADDR_ANY; + server_addr.sin_port = 0;// we want an ephemeral port + if (bind(impl->socket, (struct sockaddr*)&server_addr, sizeof(server_addr)) < 0) + { + ps_print_socket_error("error binding tcp transport socket"); + } + + socklen_t outlen = sizeof(struct sockaddr_in); + struct sockaddr_in outaddr; + getsockname(impl->socket, (struct sockaddr*)&outaddr, &outlen); + transport->transport_info = ntohs(outaddr.sin_port); + + printf("Bound tcp to %i\n", transport->transport_info); + + // set non-blocking +#ifdef _WIN32 + DWORD nonBlocking = 1; + if (ioctlsocket(impl->socket, FIONBIO, &nonBlocking) != 0) + { + ps_print_socket_error("Failed to Set Socket as Non-Blocking"); + closesocket(impl->socket); + return; + } +#endif +#ifdef ARDUINO + fcntl(impl->socket, F_SETFL, O_NONBLOCK); +#endif +#ifdef __unix__ + int flags = fcntl(impl->socket, F_GETFL); + fcntl(impl->socket, F_SETFL, flags | O_NONBLOCK); +#endif + + listen(impl->socket, 5); + + ps_event_set_add_socket(&node->events, impl->socket); + + transport->impl = (void*)impl; +} diff --git a/src/UDPTransport.c b/src/UDPTransport.c index 65a655f..1efba38 100644 --- a/src/UDPTransport.c +++ b/src/UDPTransport.c @@ -4,7 +4,7 @@ #include -void ps_udp_publish(struct ps_pub_t* pub, struct ps_client_t* client, struct ps_msg_t* msg) +void ps_udp_publish(struct ps_pub_t* pub, struct ps_client_t* client, struct ps_msg_ref_t* msg) { // send da udp packet! struct sockaddr_in address; @@ -16,10 +16,9 @@ void ps_udp_publish(struct ps_pub_t* pub, struct ps_client_t* client, struct ps_ //need to add in the topic id struct ps_msg_header* hdr = (struct ps_msg_header*)msg->data; hdr->pid = PS_UDP_PROTOCOL_DATA; + hdr->length = msg->len; hdr->id = client->stream_id; hdr->seq = client->sequence_number++; - hdr->index = 0; - hdr->count = 1;// todo use me for larger packets int sent_bytes = sendto(pub->node->socket, (const char*)msg->data, msg->len + sizeof(struct ps_msg_header), 0, (struct sockaddr*)&address, sizeof(struct sockaddr_in)); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 6d3c88a..68c64fd 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -6,8 +6,8 @@ foreach(test_file ${tests}) message(STATUS "create ${test_file} executable from source: ${test_source}") add_executable("${test_file}" "${test_source}") target_include_directories("${test_file}" PUBLIC ../include) - add_dependencies("${test_file}" pubsub pubsub_msgs) - target_link_libraries("${test_file}" pubsub ${CMAKE_THREAD_LIBS_INIT} pubsub_msgs) + add_dependencies("${test_file}" pubsub pubsub_msgs pubsub_cpp) + target_link_libraries("${test_file}" pubsub ${CMAKE_THREAD_LIBS_INIT} pubsub_msgs pubsub_test_msgs pubsub_cpp) # Auto populate the tests from test source file # Note : CMake must be reconfigured if tests are added/renamed/removed from test source file @@ -17,5 +17,6 @@ foreach(test_file ${tests}) STRING(REGEX REPLACE ${TEST_REGEX} "\\1" test ${test}) message(STATUS " add test: ${test}") add_test(NAME ${test_file}_${test} COMMAND ${test_file} ${test}) + set_tests_properties(${test_file}_${test} PROPERTIES TIMEOUT 10) endforeach() endforeach() diff --git a/tests/Test.msg b/tests/Test.msg new file mode 100644 index 0000000..efb9c01 --- /dev/null +++ b/tests/Test.msg @@ -0,0 +1,25 @@ +# A message that tests a wide variety of things +int32 test_int + +string test_string + +struct Struct +astring:10 test_array_string +uint8 i1 +int64 i2 +float f3 +double d4 +double a5[3] + +TYPE_INT=0 +TYPE_DOUBLE=1 +enum uint8 test_enum +end_struct + +Struct test_struct +Struct test_structs[2] +Struct test_struct_array[] + +TYPE2_FLAG1=1 +TYPE2_FLAG2=2 +bitmask uint8 test_bitmask[3] diff --git a/tests/mini_mock.hpp b/tests/mini_mock.hpp index 28c34f4..79c9ba5 100644 --- a/tests/mini_mock.hpp +++ b/tests/mini_mock.hpp @@ -46,6 +46,7 @@ #include #include #include +#include #include // Useful console colors @@ -83,6 +84,28 @@ static int mini_mock_failed_conditions_count = 0; } \ } +// If an exception matching thie provided message is thrown +// - an automatic message will be printed (with file name and line number) +// - the test continues +// - the test will fail at the end +void EXPECT_THROWS(std::function function, const std::string& message) +{ + bool threw = false; + try + { + function(); + } + catch (std::exception& err) + { + threw = true; + if (message.length()) + { + EXPECT(std::string(err.what()) == message); + } + } + EXPECT(threw); +} + // If condition is false : // - the given message will be printed // - the test fails and stops immediately diff --git a/tests/test_cli.cpp b/tests/test_cli.cpp index 325da47..4738868 100644 --- a/tests/test_cli.cpp +++ b/tests/test_cli.cpp @@ -94,13 +94,13 @@ TEST(test_cli_pub_latched, []() { bool got_message = false; pubsub::Subscriber subscriber(node, "/data", [&](const pubsub::msg::StringSharedPtr& msg) { - printf("Got message %s in sub1\n", msg->value); - EXPECT(strcmp("hello", msg->value) == 0); + printf("Got message %s in sub1\n", msg->value.c_str()); + EXPECT(msg->value == "hello"); got_message = true; spinner.stop(); }, 10); - spinner.wait(); + spinner.run(); EXPECT(got_message); run = false; @@ -120,13 +120,13 @@ TEST(test_cli_pub, []() { bool got_message = false; pubsub::Subscriber subscriber(node, "/data", [&](const pubsub::msg::StringSharedPtr& msg) { - printf("Got message %s in sub1\n", msg->value); - EXPECT(strcmp("hello", msg->value) == 0); + printf("Got message %s in sub1\n", msg->value.c_str()); + EXPECT(msg->value == "hello"); got_message = true; spinner.stop(); }, 10); - spinner.wait(); + spinner.run(); EXPECT(got_message); run = false; diff --git a/tests/test_pubsub_c.cpp b/tests/test_pubsub_c.cpp index 2b37b45..a3f87b1 100644 --- a/tests/test_pubsub_c.cpp +++ b/tests/test_pubsub_c.cpp @@ -6,23 +6,18 @@ #include #include +#include #include "mini_mock.hpp" /*add test to make sure mismatched messages are detected and not received also add a test to test subscriber/publisher numbers -add a way to add a timeout to tests - add a test for generic message handling make the pose viewer also be able to view odom in pubviz (maybe think of a way to view velocities)*/ -//lets test queue size too - -//so for that test, shove over N messages - TEST(test_publish_subscribe_generic, []() { struct ps_node_t node; ps_node_init(&node, "test_node", "", true); @@ -44,13 +39,12 @@ TEST(test_publish_subscribe_generic, []() { struct ps_subscriber_options options; ps_subscriber_options_init(&options); //options.skip = skip; - options.queue_size = 0; options.allocator = 0; options.ignore_local = false; static bool got_message = false; //options.preferred_transport = tcp ? 1 : 0; - options.cb = [](void* message, unsigned int size, void* data2, const ps_msg_info_t* info) + options.cb_raw = [](void* message, unsigned int size, void* data2, const ps_msg_info_t* info) { got_message = true; // todo need to also assert we have the message type @@ -58,8 +52,9 @@ TEST(test_publish_subscribe_generic, []() { auto data = (struct pubsub__String*)pubsub__String_decode(message, &ps_default_allocator); printf("Got message: %s\n", data->value); EXPECT(strcmp(data->value, rmsg.value) == 0); - free(data->value); - free(data); + pubsub__String_free(data, &ps_default_allocator); + + free(message); }; ps_node_create_subscriber_adv(&node, "/data", 0, &string_sub, &options); @@ -76,7 +71,7 @@ TEST(test_publish_subscribe_generic, []() { ps_node_destroy(&node); }); -void latch_test(bool broadcast, bool tcp) +void latch_test_cb(bool broadcast, bool tcp) { struct ps_node_t node; ps_node_init(&node, "test_node", "", broadcast); @@ -88,64 +83,116 @@ void latch_test(bool broadcast, bool tcp) struct ps_pub_t string_pub; ps_node_create_publisher(&node, "/data", &pubsub__String_def, &string_pub, true); - struct ps_sub_t string_sub; + // come up with the latched topic + static struct pubsub__String rmsg; + rmsg.value = "Hello"; + ps_pub_publish_ez(&string_pub, &rmsg); + + static bool got_message = false; + got_message = false; + struct ps_sub_t string_sub; struct ps_subscriber_options options; ps_subscriber_options_init(&options); options.preferred_transport = tcp ? 1 : 0;// tcp yo + options.cb = [](void* message, unsigned int size, void* cb_data, const struct ps_msg_info_t* info) + { + auto data = (struct pubsub__String*)message; + EXPECT(strcmp(data->value, rmsg.value) == 0); + pubsub__String_free(data, &ps_default_allocator); + got_message = true; + }; ps_node_create_subscriber_adv(&node, "/data", &pubsub__String_def, &string_sub, &options); - // come up with the latched topic - struct pubsub__String rmsg; - rmsg.value = "Hello"; - ps_pub_publish_ez(&string_pub, &rmsg); - - bool got_message = false; // now spin and wait for us to get the published message - while (ps_okay()) + while (ps_okay() && !got_message) { - ps_node_spin(&node);// todo blocking wait first - - struct pubsub__String* data; - while (data = (struct pubsub__String*)ps_sub_deque(&string_sub)) - { - // user is responsible for freeing the message and its arrays - printf("Got message: %s\n", data->value); - EXPECT(strcmp(data->value, rmsg.value) == 0); - got_message = true; - free(data->value); - free(data);//todo use allocator free - goto done; - } + ps_node_spin(&node); ps_sleep(1); } -done: EXPECT(got_message); + ps_node_destroy(&node); } -TEST(test_publish_subscribe_latched_multicast, []() { - latch_test(false, false); +TEST(test_publish_subscribe_latched_cb_multicast, []() { + latch_test_cb(false, false); }); -TEST(test_publish_subscribe_latched_broadcast, []() { - latch_test(true, false); +TEST(test_publish_subscribe_latched_cb_broadcast, []() { + latch_test_cb(true, false); }); -TEST(test_publish_subscribe_latched_multicast_tcp, []() { - latch_test(false, true); +TEST(test_publish_subscribe_latched_cb_multicast_tcp, []() { + latch_test_cb(false, true); }); -TEST(test_publish_subscribe_latched_broadcast_tcp, []() { - latch_test(true, true); +TEST(test_publish_subscribe_latched_cb_broadcast_tcp, []() { + latch_test_cb(true, true); }); -void latch_test_cb(bool broadcast, bool tcp) -{ +// test sending a very large message +TEST(test_publish_subscribe_large, []() { struct ps_node_t node; - ps_node_init(&node, "test_node", "", broadcast); + ps_node_init(&node, "test_node", "", true); + + struct ps_transport_t tcp_transport; + ps_tcp_transport_init(&tcp_transport, &node); + ps_node_add_transport(&node, &tcp_transport); + + struct ps_pub_t string_pub; + ps_node_create_publisher(&node, "/data", &pubsub__PointCloud_def, &string_pub, true); + + // come up with the latched topic + static struct pubsub__PointCloud rmsg; + rmsg.num_points = 100000000;// 10 million points! + rmsg.point_type = pubsub::msg::PointCloud::POINT_XYZ; + rmsg.data_length = rmsg.num_points*4*3;//3 floats per point + rmsg.data = (uint8_t*)malloc(rmsg.data_length); + ps_pub_publish_ez(&string_pub, &rmsg); + + struct ps_sub_t string_sub; + + struct ps_subscriber_options options; + ps_subscriber_options_init(&options); + options.allocator = 0; + options.ignore_local = false; + + static bool got_message = false; + options.preferred_transport = 1; + options.cb_raw = [](void* message, unsigned int size, void* data2, const ps_msg_info_t* info) + { + got_message = true; + printf("Got message\n"); + // todo need to also assert we have the message type + // which is tricky for udp... + auto data = (struct pubsub__PointCloud*)pubsub__PointCloud_decode(message, &ps_default_allocator); + printf("Decoded message\n"); + EXPECT(data->num_points == rmsg.num_points); + pubsub__PointCloud_free(data, &ps_default_allocator); + free(message); + }; + ps_node_create_subscriber_adv(&node, "/data", &pubsub__PointCloud_def, &string_sub, &options); + + // now spin and wait for us to get the published message + while (ps_okay() && !got_message) + { + ps_node_spin(&node);// todo blocking wait first + + ps_sleep(1); + } + +done: + EXPECT(got_message); + ps_node_destroy(&node); +}); + +TEST(test_publish_subscribe_latched_skip, []() { + // test that we still get the latched message even if we want to skip messages + struct ps_node_t node; + ps_node_init(&node, "test_node", "", false); struct ps_transport_t tcp_transport; ps_tcp_transport_init(&tcp_transport, &node); @@ -165,18 +212,16 @@ void latch_test_cb(bool broadcast, bool tcp) struct ps_sub_t string_sub; struct ps_subscriber_options options; ps_subscriber_options_init(&options); - options.preferred_transport = tcp ? 1 : 0;// tcp yo + options.skip = 100; options.cb = [](void* message, unsigned int size, void* cb_data, const struct ps_msg_info_t* info) { auto data = (struct pubsub__String*)message; EXPECT(strcmp(data->value, rmsg.value) == 0); - free(data->value); - free(data);//todo use allocator free + pubsub__String_free(data, &ps_default_allocator); got_message = true; }; ps_node_create_subscriber_adv(&node, "/data", &pubsub__String_def, &string_sub, &options); - // now spin and wait for us to get the published message while (ps_okay() && !got_message) { @@ -187,22 +232,58 @@ void latch_test_cb(bool broadcast, bool tcp) EXPECT(got_message); ps_node_destroy(&node); -} - -TEST(test_publish_subscribe_latched_cb_multicast, []() { - latch_test_cb(false, false); }); -TEST(test_publish_subscribe_latched_cb_broadcast, []() { - latch_test_cb(true, false); -}); +TEST(test_publish_subscribe_skip, []() { + // test that skip works correctly + struct ps_node_t node; + ps_node_init(&node, "test_node", "", false); -TEST(test_publish_subscribe_latched_cb_multicast_tcp, []() { - latch_test_cb(false, true); -}); + struct ps_transport_t tcp_transport; + ps_tcp_transport_init(&tcp_transport, &node); + ps_node_add_transport(&node, &tcp_transport); -TEST(test_publish_subscribe_latched_cb_broadcast_tcp, []() { - latch_test_cb(true, true); + struct ps_pub_t string_pub; + ps_node_create_publisher(&node, "/data", &pubsub__String_def, &string_pub, true); + + // come up with the latched topic + static struct pubsub__String rmsg; + rmsg.value = "Hello"; + ps_pub_publish_ez(&string_pub, &rmsg); + + struct ps_sub_t string_sub; + struct ps_subscriber_options options; + ps_subscriber_options_init(&options); + options.skip = 10; + static int received = 0; + options.cb = [](void* message, unsigned int size, void* cb_data, const struct ps_msg_info_t* info) + { + auto data = (struct pubsub__String*)message; + EXPECT(strcmp(data->value, rmsg.value) == 0); + pubsub__String_free(data, &ps_default_allocator); + received++; + }; + ps_node_create_subscriber_adv(&node, "/data", &pubsub__String_def, &string_sub, &options); + + // first spin and wait for connection + while (ps_okay() && ps_pub_get_subscriber_count(&string_pub) == 0) + { + ps_node_spin(&node); + ps_sleep(1); + } + + // now spin and publish + for (int i = 0; i < 100; i++) + { + ps_node_spin(&node); + ps_pub_publish_ez(&string_pub, &rmsg); + ps_sleep(1); + } + + // finally count the number of messages + EXPECT(received == 10); + + ps_node_destroy(&node); }); CREATE_MAIN_ENTRY_POINT(); diff --git a/tests/test_pubsub_cpp.cpp b/tests/test_pubsub_cpp.cpp index 8967804..334c7ba 100644 --- a/tests/test_pubsub_cpp.cpp +++ b/tests/test_pubsub_cpp.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include "mini_mock.hpp" @@ -15,23 +16,120 @@ TEST(test_publish_subscribe_latched_cpp, []() { pubsub::msg::String omsg; omsg.value = "Hello"; string_pub.publish(omsg); + + pubsub::BlockingSpinnerWithTimers spinner; + spinner.setNode(node); + bool got_message = false; + pubsub::Subscriber subscriber(node, "/data", [&](const pubsub::msg::StringSharedPtr& msg) { + printf("Got message %s in sub1\n", msg->value.c_str()); + EXPECT(omsg.value == msg->value); + got_message = true; + spinner.stop(); + + // make sure subscriber count is correct + EXPECT(string_pub.getNumSubscribers() == 1); + }, 10); + + spinner.run(); + EXPECT(got_message); +}); + +TEST(test_publish_subscribe_zero_copy, []() { + // test that data gets passed through without copying within a single node + pubsub::Node node("simple_publisher"); + + pubsub::Publisher string_pub(node, "/data", true); + + pubsub::msg::StringSharedPtr omsg(new pubsub::msg::String()); + omsg->value = "Hello"; + string_pub.publish(omsg); pubsub::BlockingSpinnerWithTimers spinner; spinner.setNode(node); bool got_message = false; pubsub::Subscriber subscriber(node, "/data", [&](const pubsub::msg::StringSharedPtr& msg) { - printf("Got message %s in sub1\n", msg->value); - EXPECT(strcmp(omsg.value, msg->value) == 0); + printf("Got message %s in sub1\n", msg->value.c_str()); + EXPECT(omsg->value == msg->value); got_message = true; + EXPECT(msg.get() == omsg.get()); spinner.stop(); + + // make sure subscriber count is correct + EXPECT(string_pub.getNumSubscribers() == 1); }, 10); - spinner.wait(); + spinner.run(); EXPECT(got_message); }); +TEST(test_publish_subscribe_queue_behavior, []() { + // test that the queue drops messages as expected and we only get the newest + pubsub::Node node("simple_publisher"); + + pubsub::Publisher int_pub(node, "/data"); + + pubsub::BlockingSpinnerWithTimers spinner; + spinner.setNode(node); + + std::vector received; + pubsub::Subscriber subscriber(node, "/data", [&](const pubsub::msg::IntSharedPtr& msg) { + printf("Got message %li in sub1\n", msg->value); + received.push_back(msg->value); + spinner.stop(); + }, 10); + + for (int i = 0; i < 100; i++) + { + pubsub::msg::Int omsg; + omsg.value = i; + int_pub.publish(omsg); + } + + // spin after publishing so the queue fills up + spinner.run(); + EXPECT(received.size() == 10); + for (int i = 0; i < 10; i++) + { + EXPECT(received[i] == 90 + i); + } +}); + +TEST(test_publish_subscribe_nodelets, []() { + // test that data gets passed through without copying between multiple nodes + pubsub::Node nodep("simple_publisher"); + pubsub::Node nodes("simple_subscriber"); + + pubsub::Publisher string_pub(nodep, "/data", true); + + pubsub::msg::StringSharedPtr omsg(new pubsub::msg::String()); + omsg->value = "Hello"; + string_pub.publish(omsg); + + pubsub::BlockingSpinnerWithTimers spinner; + spinner.setNode(nodep); + + pubsub::BlockingSpinnerWithTimers spinner2; + spinner2.setNode(nodes); + + bool got_message = false; + pubsub::Subscriber subscriber(nodes, "/data", [&](const pubsub::msg::StringSharedPtr& msg) { + printf("Got message %s in sub1\n", msg->value.c_str()); + EXPECT(omsg->value == msg->value); + got_message = true; + EXPECT(msg.get() == omsg.get()); + spinner2.stop(); + spinner.stop(); + }, 10); + + spinner2.start(); + spinner.run(); + spinner2.wait(); + EXPECT(got_message); +}); +//okay, things to add, fixed length strings so we can have strings in structs +//should also enhance bindings for strings, I think C++ strings have memory leaks // Make sure close works on publishers/subscribers and doesnt result in them getting closed multiple times TEST(test_publisher_subscriber_close_cpp, []() { pubsub::Node node("simple_publisher"); @@ -45,6 +143,7 @@ TEST(test_publisher_subscriber_close_cpp, []() { sub.close(); sub.close(); }); + TEST(test_publish_subscribe_cpp, []() { // test that normal messages make it through message passing pubsub::Node node("simple_publisher"); @@ -59,8 +158,8 @@ TEST(test_publish_subscribe_cpp, []() { bool got_message = false; pubsub::Subscriber subscriber(node, "/data", [&](const pubsub::msg::StringSharedPtr& msg) { - printf("Got message %s in sub1\n", msg->value); - EXPECT(strcmp(omsg.value, msg->value) == 0); + printf("Got message %s in sub1\n", msg->value.c_str()); + EXPECT(omsg.value == msg->value); spinner.stop(); got_message = true; }, 10); @@ -70,8 +169,79 @@ TEST(test_publish_subscribe_cpp, []() { string_pub.publish(omsg); }); - spinner.wait(); + spinner.run(); + EXPECT(got_message); +}); + +// tracking allocator usage +static int allocated; +static int freed; +static std::map sizes; +static ps_allocator_t alloc; +struct TestAllocator +{ + static ps_allocator_t* allocator() { + return &alloc; + } + + static void* Allocate(uint32_t size, void* context) + { + auto ptr = malloc(size); + sizes[ptr] = size; + allocated += size; + printf("allocate %i\n", allocated); + return ptr; + } + + static void Free(void* ptr, void* context) + { + freed += sizes[ptr]; + printf("free %i\n", freed); + free(ptr); + } + + static void Setup() + { + allocated = 0; + freed = 0; + alloc.context = 0; + alloc.free = Free; + alloc.alloc = Allocate; + } +}; + +TEST(test_publish_subscribe_allocator_cpp, []() { + TestAllocator::Setup(); + // test that allocators are used with C++ + pubsub::Node node("simple_publisher"); + bool got_message = false; + { + pubsub::Publisher> string_pub(node, "/data"); + + pubsub::msg::String_ omsg; + omsg.value = "Hello"; + + pubsub::BlockingSpinnerWithTimers spinner; + spinner.setNode(node); + + pubsub::Subscriber> subscriber(node, "/data", [&](const pubsub::msg::String_::SharedPtr& msg) { + printf("Got message %s in sub1\n", msg->value.c_str()); + EXPECT(omsg.value == msg->value); + spinner.stop(); + got_message = true; + }, 10); + + spinner.addTimer(0.1, [&]() + { + string_pub.publish(omsg); + }); + + spinner.run(); + } EXPECT(got_message); + + EXPECT(allocated == 20); + EXPECT(allocated == freed); }); CREATE_MAIN_ENTRY_POINT(); diff --git a/tests/test_serialization.cpp b/tests/test_serialization.cpp index eda4f6f..a99e61a 100644 --- a/tests/test_serialization.cpp +++ b/tests/test_serialization.cpp @@ -2,13 +2,14 @@ #include #include #include +#include +#include #include "mini_mock.hpp" TEST(test_joy_serialization, []() { - // try serializing then deserializing a message to make sure it all matches - pubsub::msg::Joy msg; + pubsub::msg::Joy msg; msg.buttons = 0x12345678; for (int i = 0; i < 8; i++) msg.axes[i] = i; @@ -26,10 +27,70 @@ TEST(test_joy_serialization, []() { free(out); }); -void test2() -{ +TEST(test_string_cpp, []() { + std::string value = "hi"; + // test C++ strings + { + pubsub::msg::String msg; + EXPECT(msg.value.data() == 0) + EXPECT(msg.value == ""); + msg.value = "apples"; + EXPECT(msg.value == "apples"); + EXPECT(msg.value == std::string("apples")); + EXPECT(strcmp(msg.value.c_str(), "apples") == 0); + msg.value = value; + EXPECT(msg.value == value); + EXPECT(msg.value == value.c_str()); + EXPECT(strcmp(msg.value.c_str(), value.c_str()) == 0); + + // make sure the C++ string object is the same size as a pointer + EXPECT(sizeof(msg.value) == sizeof(char*)); + } +}); + +TEST(test_fixed_string_cpp, []() { + // test C++ fixed strings + { + pubsub::FixedString<5> string; + string = "hi"; + + EXPECT(string == "hi"); + EXPECT(string == std::string("hi")); + EXPECT(strcmp(string.c_str(), "hi") == 0); + EXPECT(sizeof(string) == 5); + } + + // test what happens when you assign too much + { + pubsub::FixedString<5> string; + EXPECT_THROWS([&](){ + string = "hello paul"; + }, "Too big."); + } +}); + +TEST(test_array_vector_cpp, []() { + // test C++ array vectors + { + pubsub::ArrayVector vector; + EXPECT(sizeof(vector) == 4 + sizeof(char*)); + EXPECT(vector.size() == 0); + vector.resize(4); + EXPECT(vector.size() == 4) + + for (auto& item: vector) + { + item = 55; + } + + EXPECT(vector[0] == 55); + EXPECT(vector[3] == 55); + } +}); + +TEST(test_costmap_c_cpp, []() { // verify that the C type matches the C++ one memory wise - pubsub::msg::Costmap msg; + pubsub::msg::Costmap msg; msg.width = 100; msg.height = 200; msg.resolution = 1.0; @@ -49,15 +110,11 @@ void test2() if (msg.data[i] != cmsg->data[i]) EXPECT(false); } -} - -TEST(test_costmap_c_cpp, []() { - test2(); }); TEST(test_costmap_serialization, []() { // try serializing then deserializing a message to make sure it all matches - pubsub::msg::Costmap msg; + pubsub::msg::Costmap msg; msg.width = 100; msg.height = 200; msg.resolution = 1.0; @@ -83,12 +140,12 @@ TEST(test_costmap_serialization, []() { if (out->data[i] != msg.data[i]) EXPECT(false); } - delete out; + delete out; }); TEST(test_path2d_serialization, []() { // try serializing then deserializing a message to make sure it all matches - pubsub::msg::Path2D msg; + pubsub::msg::Path2D msg; msg.frame = 100; msg.points.resize(123); for (int i = 0; i < msg.points.size(); i++) @@ -114,9 +171,33 @@ TEST(test_path2d_serialization, []() { delete out; }); +TEST(test_path2d_copy, []() { + // make sure copying a message works + pubsub::msg::Path2D msg; + msg.frame = 100; + msg.points.resize(123); + for (int i = 0; i < msg.points.size(); i++) + { + msg.points[i].x = i*2; + msg.points[i].y = i*2 + 1; + } + + pubsub::msg::Path2D msg2 = msg; + + EXPECT(msg2.frame == msg.frame); + EXPECT(msg2.points.size() == msg.points.size()); + for (int i = 0; i < msg.points.size(); i++) + { + if (msg2.points[i].x != msg.points[i].x) + EXPECT(false); + if (msg2.points[i].y != msg.points[i].y) + EXPECT(false); + } +}); + TEST(test_path2d_foreach, []() { // try serializing then deserializing a message, making sure the foreach loop over it works as expected - pubsub::msg::Path2D msg; + pubsub::msg::Path2D msg; msg.frame = 100; msg.points.resize(3); for (int i = 0; i < msg.points.size(); i++) @@ -143,4 +224,144 @@ TEST(test_path2d_foreach, []() { delete out; }); +void compare_struct(const pubsub::msg::Test::Struct& a, const pubsub::msg::Test::Struct& b) +{ + EXPECT(a.test_array_string == b.test_array_string); + EXPECT(a.i1 == b.i1); + EXPECT(a.i2 == b.i2); + EXPECT(a.f3 == b.f3); + EXPECT(a.d4 == b.d4); + for (int i = 0; i < 3; i++) + EXPECT(a.a5[i] == b.a5[i]); +} + +void fill_struct(pubsub::msg::Test::Struct& a) +{ + a.test_array_string = "ok"; + a.i1 = rand(); + a.i2 = rand(); + a.f3 = rand()/100.0; + a.d4 = rand()/100.0; + for (int i = 0; i < 3; i++) + a.a5[i] = rand()/100.0; + a.test_enum = pubsub::msg::Test::TYPE_INT; +} + +TEST(test_complex_message, []() { + // try serializing then deserializing a message, making sure everything matches expectation + EXPECT(sizeof(pubsub::msg::Test) == sizeof(pubsub__Test)); + EXPECT(sizeof(pubsub::msg::Test::Struct) == 56); + pubsub::msg::Test msg; + msg.test_int = rand(); + msg.test_string = "THIS IS A STRING"; + fill_struct(msg.test_struct); + fill_struct(msg.test_structs[0]); + fill_struct(msg.test_structs[1]); + msg.test_struct_array.resize(4); + for (int i = 0; i < 4; i++) + fill_struct(msg.test_struct_array[i]); + msg.test_bitmask[0] = pubsub::msg::Test::TYPE2_FLAG1; + msg.test_bitmask[1] = pubsub::msg::Test::TYPE2_FLAG2; + msg.test_bitmask[2] = pubsub::msg::Test::TYPE2_FLAG1 | pubsub::msg::Test::TYPE2_FLAG2; + + ps_msg_t in = msg.Encode(); + EXPECT(in.len == 420);// 56*7 + 4 + 4 + 3 + 17 = 392 + 11 + 17 + + // Make sure deserialize iterators work correctly + struct ps_deserialize_iterator iter = ps_deserialize_start(ps_get_msg_start((const char*)in.data), pubsub::msg::Test::GetDefinition()); + const struct ps_msg_field_t* field; uint32_t length; const void* ptr; + std::vector fields; + while (ptr = ps_deserialize_iterate(&iter, &field, &length)) + { + fields.push_back(field->name); + + // check some values + if (strcmp(field->name, "test_bitmask") == 0) + { + EXPECT(*(uint8_t*)ptr == msg.test_bitmask[0]); + } + if (strcmp(field->name, "test_int") == 0) + { + EXPECT(*(uint32_t*)ptr == msg.test_int); + } + } + EXPECT(fields.size() == 6); + EXPECT(fields[0] == "test_int"); + EXPECT(fields[5] == "test_bitmask"); + + auto* out = pubsub::msg::Test::Decode(ps_get_msg_start(in.data)); + free(in.data); + + EXPECT(msg.test_int == out->test_int); + EXPECT(msg.test_string == out->test_string); + compare_struct(msg.test_struct, out->test_struct); + compare_struct(msg.test_structs[0], out->test_structs[0]); + compare_struct(msg.test_structs[1], out->test_structs[1]); + EXPECT(msg.test_struct_array.size() == out->test_struct_array.size()) + for (int i = 0; i < 4; i++) + compare_struct(msg.test_struct_array[i], out->test_struct_array[i]); + for (int i = 0; i < 3; i++) + EXPECT(msg.test_bitmask[i] == out->test_bitmask[i]); +}); + + +// tracking allocator usage +static int allocated; +static int freed; +static std::map sizes; +static ps_allocator_t alloc; +struct TestAllocator +{ + static ps_allocator_t* allocator() { + return &alloc; + } + + static void* Allocate(uint32_t size, void* context) + { + auto ptr = malloc(size); + sizes[ptr] = size; + allocated += size; + printf("allocate %i\n", allocated); + return ptr; + } + + static void Free(void* ptr, void* context) + { + freed += sizes[ptr]; + printf("free %i\n", freed); + free(ptr); + } + + static void Setup() + { + allocated = 0; + freed = 0; + alloc.context = 0; + alloc.free = Free; + alloc.alloc = Allocate; + } +}; + +TEST(test_message_allocators, []() { + // make sure the allocator gets used + TestAllocator::Setup(); + + auto msg = new pubsub::msg::Costmap_(); + msg->data.resize(1000); + + delete msg; + + EXPECT(allocated == 1041); + EXPECT(freed > 0); + EXPECT(freed == allocated); + + auto msg2 = new pubsub::msg::String_(); + msg2->value = "hi"; + + delete msg2; + + EXPECT(allocated == 1041+4+3+4); + EXPECT(freed == allocated); +}); + CREATE_MAIN_ENTRY_POINT(); diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index fea1c57..cf7315a 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -53,7 +53,7 @@ macro(generate_messages) target_include_directories(${ARG_TARGET} INTERFACE ${CMAKE_CURRENT_BINARY_DIR}/include/) endmacro() -generate_messages(FILES +generate_messages(FILES Joy.msg String.msg Float.msg @@ -69,6 +69,15 @@ generate_messages(FILES Log.msg Path2D.msg Odom.msg + Int.msg +) + +generate_messages(FILES + Test.msg +TARGET + pubsub_test_msgs +DIRECTORY + tests ) diff --git a/tools/PubSubTest.cpp b/tools/PubSubTest.cpp index 05cb4ee..cc77a3c 100644 --- a/tools/PubSubTest.cpp +++ b/tools/PubSubTest.cpp @@ -23,7 +23,13 @@ int main() ps_node_create_publisher(&node, "/joy", &pubsub__Joy_def, &adv_pub, false); ps_sub_t string_sub; - ps_node_create_subscriber(&node, "/data", &pubsub__String_def, &string_sub, 10, 0, false); + auto cb = [](void* message, unsigned int size, void* cbdata, const ps_msg_info_t* info) + { + pubsub__String* data = (pubsub__String*)message; + free(data->value); + free(data);//todo use allocator free + }; + ps_node_create_subscriber(&node, "/data", &pubsub__String_def, &string_sub, cb, 0, 0, false); // wait until we get the subscription request while (ps_pub_get_subscriber_count(&string_pub) == 0) @@ -61,13 +67,6 @@ int main() ps_node_spin(&node); //while (ps_node_spin(&node) == 0) { Sleep(1); } - // our sub has a message definition, so the queue contains real messages - while (pubsub__String* data = (pubsub__String*)ps_sub_deque(&string_sub)) - { - printf("Got message: %s\n", data->value); - free(data->value); - free(data);//todo use allocator free - } printf("Num subs: %i %i\n", ps_pub_get_subscriber_count(&string_pub), ps_pub_get_subscriber_count(&adv_pub)); ps_sleep(1000); diff --git a/tools/generator.cpp b/tools/generator.cpp index 7484e63..5e9b23c 100644 --- a/tools/generator.cpp +++ b/tools/generator.cpp @@ -4,6 +4,7 @@ #include +#include #include #include #include @@ -26,16 +27,16 @@ struct enumeration { std::string name; std::string value; - int field_num; }; +struct field; struct Type { std::string name; std::string base_type; std::string type_enum; // if zero size, a basic type - std::vector> fields; + std::vector fields; }; // stupid hack @@ -48,16 +49,20 @@ struct field int array_size; + int string_size; + std::string flag; uint32_t line_number; + + std::vector associated_enums; - std::string getBaseType() + std::string getBaseType() const { return type->base_type; } - std::string getFlags() + std::string getFlags() const { if (flag == "enum") return "FF_ENUM"; @@ -71,12 +76,12 @@ struct field return "invalid"; } - std::string getTypeEnum() + std::string getTypeEnum() const { return type->type_enum; } - void GenerateFree(std::string& output) + void GenerateFree(std::string& output) const { if (type->name == "string" && array_size != 1) { @@ -100,7 +105,7 @@ struct field } } - void GenerateCopy(std::string& output, const std::string& source) + void GenerateCopy(std::string& output, const std::string& source) const { if (type->name == "string") { @@ -223,8 +228,10 @@ std::string generate(const char* definition, const char* name) types["int8"] = new Type{"int8", "int8_t", "FT_Int8", {}}; types["float"] = new Type{"float", "float", "FT_Float32", {}}; types["double"] = new Type{"double", "double", "FT_Float64", {}}; + types["astring"] = new Type{"astring", "char", "FT_ArrayString", {}}; // also generate the hash while we are at it + std::vector unassociated_enums; uint32_t hash = 0; uint32_t line_number = 0; for (auto& line : lines) @@ -294,8 +301,23 @@ std::string generate(const char* definition, const char* name) current_struct->name = name; current_struct->base_type = name; current_struct->type_enum = "FT_Struct"; + unassociated_enums.clear();// if they come before a struct defintion we arent sure what they go with continue; } + + int string_size = 0; + if (type.find(':') != -1) + { + int index = type.find(':'); + string_size = std::stoi(type.substr(index + 1)); + type = type.substr(0, index); + + if (type.find("astring") == -1) + { + printf("%s:%i ERROR: The element count syntax can only be used with astrings'.\n", _current_file.c_str(), line_number); + throw 7; + } + } // lookup the type Type* real_type = 0; @@ -313,11 +335,13 @@ std::string generate(const char* definition, const char* name) if (current_struct) { // this belongs in the struct - current_struct->fields.push_back({real_type, name}); + current_struct->fields.push_back(new field{ name, real_type, size, string_size, "", line_number, unassociated_enums }); + unassociated_enums.clear(); continue; } // also fill in array size - fields.push_back({ name, real_type, size, "", line_number }); + fields.push_back({ name, real_type, size, string_size, "", line_number, unassociated_enums }); + unassociated_enums.clear(); } // a line with flags maybe? else if (words.size() == 3 && !has_equal) @@ -352,8 +376,17 @@ std::string generate(const char* definition, const char* name) throw 7; } + if (current_struct) + { + // this belongs in the struct + current_struct->fields.push_back(new field{ name, real_type, size, 0, flag, line_number, unassociated_enums }); + unassociated_enums.clear(); + continue; + } + // also fill in array size - fields.push_back({ name, real_type, size, flag, line_number}); + fields.push_back({ name, real_type, size, 0, flag, line_number, unassociated_enums}); + unassociated_enums.clear(); } else { @@ -366,7 +399,9 @@ std::string generate(const char* definition, const char* name) std::string name = strip_whitespace(equals[0]); std::string value = strip_whitespace(equals[1]); //printf("it was an enum: %s=%s\n", equals[0].c_str(), equals[1].c_str()); - enumerations.push_back({ name, value, (int)fields.size()}); + + enumerations.push_back({ name, value }); + unassociated_enums.push_back(enumerations.size() - 1); } else { @@ -377,7 +412,7 @@ std::string generate(const char* definition, const char* name) } std::string raw_name = split(name, '_').back(); - std::string ns = std::string(name).substr(0, std::string(name).find_last_of('_')-1); + std::string ns = std::string(name).substr(0, std::string(name).find_last_of('_')-1); // convert the name into a type std::string type_name; @@ -419,19 +454,30 @@ std::string generate(const char* definition, const char* name) for (auto& field: type.second->fields) { // dont allow strings yet - if (field.first->base_type == "char*") + if (field->type->base_type == "char*") { - printf("%s:%i ERROR: Strings not yet allowed in structs.\n", _current_file.c_str(), line_number); + printf("%s:%i ERROR: Dynamicly sized strings not allowed in structs.\n", _current_file.c_str(), line_number); throw 7; } - if (field.first->type_enum == "FT_Struct") + if (field->type->type_enum == "FT_Struct") { printf("%s:%i ERROR: Structs not yet allowed in structs.\n", _current_file.c_str(), line_number); throw 7; } - - output += " " + field.first->base_type + " " + field.second + ";\n"; + + if (field->type == types["astring"]) + { + output += " " + field->type->base_type + " " + field->name + "[" + std::to_string(field->string_size) + "];\n"; + } + else if (field->array_size > 1) + { + output += " " + field->type->base_type + " " + field->name + "[" + std::to_string(field->array_size) + "];\n"; + } + else + { + output += " " + field->type->base_type + " " + field->name + ";\n"; + } } output += "};\n\n"; } @@ -462,38 +508,88 @@ std::string generate(const char* definition, const char* name) //ps_message_definition_t std_msgs_String_def = { 123456789, "std_msgs/String", 1, std_msgs_String_fields }; // generate the fields - output += "struct ps_msg_field_t " + type_name + "_fields[] = {\n"; - for (auto& field : fields) + output += "static struct ps_msg_field_t " + type_name + "_fields[] = {\n"; + std::map generated_structs; + std::map field_indexes; + int field_index = 0; + for (const auto& field : fields) { if (field.getTypeEnum() == "FT_Struct") { - // struct - auto& members = field.type->fields; + // add the struct if we haven't already + int struct_index; + if (generated_structs.find(field.type->name) == generated_structs.end()) + { + struct_index = field_index; + generated_structs[field.type->name] = struct_index; + + auto& members = field.type->fields; + // add the struct itself + output += " { FT_StructDefinition, FF_NONE, \"" + field.type->name + "\", "; + output += std::to_string(members.size()) + ", 0 }, \n"; + field_index++; + + // now add struct fields + for (auto& m : members) + { + output += " { " + m->type->type_enum + ", " + m->getFlags() + ", \"" + m->name + "\", "; + output += std::to_string(m->array_size) + ", " + std::to_string(m->string_size) + " }, \n"; + field_indexes[m] = field_index++; + } + } + else + { + struct_index = generated_structs[field.type->name]; + } + + // add the field output += " { " + field.getTypeEnum() + ", " + field.getFlags() + ", \"" + field.name + "\", "; - output += std::to_string(field.array_size) + ", " + std::to_string(members.size()) + " }, \n";// todo use for array types - - // now add struct fields - for (auto& m : members) - { - output += " { " + m.first->type_enum + ", " + "FF_NONE" + ", \"" + m.second + "\", "; - output += std::to_string(1) + ", 0 }, \n";// todo support array members - } + output += std::to_string(field.array_size) + ", " + std::to_string(struct_index) + " }, \n"; } else { + // add the field output += " { " + field.getTypeEnum() + ", " + field.getFlags() + ", \"" + field.name + "\", "; - output += std::to_string(field.array_size) + ", 0 }, \n";// todo use for array types + output += std::to_string(field.array_size) + ", " + std::to_string(field.string_size) + " }, \n"; } + field_indexes[&field] = field_index++; } output += "};\n\n"; // generate enum metadata if (enumerations.size()) { - output += "struct ps_msg_enum_t " + type_name + "_enums[] = {\n"; - for (auto& e: enumerations) + output += "static struct ps_msg_enum_t " + type_name + "_enums[] = {\n"; + int enum_id = 0; + for (const auto& e: enumerations) { - output += " {\"" + e.name + "\", " + e.value + ", " + std::to_string(e.field_num) + "},\n"; + // okay, lets flip the script, each field lists associated enums? + // now search for the associated field + int field_num = 255; + for (const auto& field: fields) + { + for (const auto& sf: field.type->fields) + { + for (auto id: sf->associated_enums) + { + if (id == enum_id) + { + field_num = field_indexes[sf]; + break; + } + } + } + for (auto id: field.associated_enums) + { + if (id == enum_id) + { + field_num = field_indexes[&field]; + break; + } + } + } + output += " {\"" + e.name + "\", " + e.value + ", " + std::to_string(field_num) + "},\n"; + enum_id++; } output += "};\n\n"; } @@ -536,24 +632,29 @@ std::string generate(const char* definition, const char* name) if (is_pure) { //generate simple de/serializaton - output += "void* " + type_name + "_decode(const void* data, struct ps_allocator_t* allocator)\n{\n"; + output += "static void* " + type_name + "_decode(const void* data, struct ps_allocator_t* allocator)\n{\n"; output += " struct " + type_name + "* out = (struct " + type_name + "*)allocator->alloc(sizeof(struct " + type_name + "), allocator->context);\n"; output += " *out = *(struct " + type_name + "*)data;\n"; output += " return out;\n"; output += "}\n\n"; // now for encode - output += "struct ps_msg_t " + type_name + "_encode(struct ps_allocator_t* allocator, const void* msg)\n{\n"; + output += "static struct ps_msg_t " + type_name + "_encode(const void* msg, struct ps_allocator_t* allocator)\n{\n"; output += " int len = sizeof(struct " + type_name + ");\n"; output += " struct ps_msg_t omsg;\n"; output += " ps_msg_alloc(len, allocator, &omsg);\n"; output += " memcpy(ps_get_msg_start(omsg.data), msg, len);\n"; - output += " return omsg;\n}\n"; + output += " return omsg;\n}\n\n"; + + // finally free todo use allocator + output += "static void " + type_name + "_free(void* msg, struct ps_allocator_t* allocator)\n{\n"; + output += " allocator->free(msg, allocator->context);\n"; + output += "}\n\n"; } else { //need to split it in sections between the strings - output += "void* " + type_name + "_decode(const void* data, struct ps_allocator_t* allocator)\n{\n"; + output += "static void* " + type_name + "_decode(const void* data, struct ps_allocator_t* allocator)\n{\n"; output += " char* p = (char*)data;\n"; output += " int len = sizeof(struct "+type_name+");\n"; output += " struct "+type_name+"* out = (struct " + type_name + "*)allocator->alloc(len, allocator->context);\n"; @@ -582,7 +683,7 @@ std::string generate(const char* definition, const char* name) } output += " out->" + fields[i].name + "_length = num_" + fields[i].name + ";\n"; - output += " out->" + fields[i].name + " = (char**)malloc(sizeof(char*)*num_" + fields[i].name + ");\n"; + output += " out->" + fields[i].name + " = (char**)allocator->alloc(sizeof(char*)*num_" + fields[i].name + ", allocator->context);\n"; // allocate the array // need to do it! @@ -591,7 +692,7 @@ std::string generate(const char* definition, const char* name) output += " int len = *(uint32_t*)p;\n"; output += " p += 4;\n";// add size of length // now read and allocate each string - output += " out->" + fields[i].name + "[i] = (char*)malloc(len);\n"; + output += " out->" + fields[i].name + "[i] = (char*)allocator->alloc(len, allocator->context);\n"; output += " memcpy(out->" + fields[i].name + "[i], p, len);\n"; output += " p += len;\n"; output += " }\n"; @@ -627,7 +728,7 @@ std::string generate(const char* definition, const char* name) output += "}\n\n"; //typedef ps_msg_t(*ps_fn_encode_t)(ps_allocator_t* allocator, const void* msg); - output += "struct ps_msg_t " + type_name + "_encode(struct ps_allocator_t* allocator, const void* data)\n{\n"; + output += "static struct ps_msg_t " + type_name + "_encode(const void* data, struct ps_allocator_t* allocator)\n{\n"; output += " const struct " + type_name + "* msg = (const struct " + type_name + "*)data;\n"; output += " int len = sizeof(struct " + type_name + ");\n"; output += " // calculate the encoded length of the message\n"; @@ -729,22 +830,53 @@ std::string generate(const char* definition, const char* name) } output += " return omsg;\n"; output += "}\n"; + + // finally free + output += "static void " + type_name + "_free(void* data, struct ps_allocator_t* allocator)\n{\n"; + output += " struct " + type_name + "* msg = (struct " + type_name + "*)data;\n"; + for (size_t i = 0; i < fields.size(); i++) + { + if (fields[i].type == string_type) + { + if (fields[i].array_size == 1) + { + output += " allocator->free(msg->" + fields[i].name + ", allocator->context);\n"; + } + else + { + if (fields[i].array_size == 0) + { + output += " int num_" + fields[i].name + " = msg->" + fields[i].name + "_length;\n"; + } + else + { + output += " int num_" + fields[i].name + " = " + std::to_string(fields[i].array_size) + ";\n"; + } + + output += " for (int i = 0; i < num_" + fields[i].name + "; i++) {\n"; + output += " allocator->free(msg->" + fields[i].name + "[i], allocator->context);\n"; + output += " }\n"; + output += " allocator->free(msg->" + fields[i].name + ", allocator->context);\n"; + } + } + else if (fields[i].array_size == 0) + { + output += " allocator->free(msg->" + fields[i].name + ", allocator->context);\n"; + } + } + output += " allocator->free(msg, allocator->context);\n"; + output += "}\n\n"; } // generate the actual message definition - int field_count = fields.size(); - for (auto& f: fields) - { - field_count += f.type->fields.size(); - } - output += "struct ps_message_definition_t " + type_name + "_def = { "; + output += "static struct ps_message_definition_t " + type_name + "_def = { "; if (enumerations.size() == 0) { - output += std::to_string(hash) + ", \"" + name + "\", " + std::to_string(field_count) + ", " + type_name + "_fields, " + type_name + "_encode, " + type_name + "_decode, 0, 0 };\n"; + output += std::to_string(hash) + ", \"" + name + "\", " + std::to_string(field_index) + ", " + type_name + "_fields, " + type_name + "_encode, " + type_name + "_decode, " + type_name + "_free, 0, 0 };\n"; } else { - output += std::to_string(hash) + ", \"" + name + "\", " + std::to_string(field_count) + ", " + type_name + "_fields, " + type_name + "_encode, " + type_name + "_decode, " + std::to_string(enumerations.size()) + ", " + type_name + "_enums };\n"; + output += std::to_string(hash) + ", \"" + name + "\", " + std::to_string(field_index) + ", " + type_name + "_fields, " + type_name + "_encode, " + type_name + "_decode, " + type_name + "_free, " + std::to_string(enumerations.size()) + ", " + type_name + "_enums };\n"; } output += "\n#ifdef __cplusplus\n"; @@ -752,22 +884,60 @@ std::string generate(const char* definition, const char* name) output += "#include \n"; output += "#include \n"; //output += "#include \n"; + // todo only include if needed + output += "#include \n"; output += "#include \n"; + output += "#include \n"; output += "namespace " + ns + "\n{\n"; output += "namespace msg\n{\n"; output += "#pragma pack(push, 1)\n"; - output += "struct " + raw_name + "\n{\n"; + output += "template \n"; + output += "struct " + raw_name + "_\n{\n"; + output += " typedef std::shared_ptr<" + raw_name + "_> SharedPtr;\n"; + output += " typedef std::shared_ptr> SharedConstPtr;\n"; + output += " typedef AllocatorT Allocator;\n"; + // generate internal structs + for (auto& type: types) + { + if (type.second->type_enum != "FT_Struct") + { + continue; + } + + output += " struct " + type.second->name + "\n {\n"; + for (auto& field: type.second->fields) + { + if (field->type == types["astring"]) + { + output += " pubsub::FixedString<" + std::to_string(field->string_size) + "> " + field->name + ";\n"; + } + else if (field->array_size > 1) + { + output += " " + field->type->base_type + " " + field->name + "[" + std::to_string(field->array_size) +"];\n"; + } + else + { + output += " " + field->type->base_type + " " + field->name + ";\n"; + } + } + type.second->base_type = type.second->name; + output += " };\n\n"; + } for (auto f: fields) { std::string type = f.type == string_type ? "char*" : f.getBaseType(); - if (f.array_size == 1) + if (f.type == string_type && f.array_size == 1) + { + output += " pubsub::CString " + f.name + ";\n"; + } + else if (f.array_size == 1) { output += " " + type + " " + f.name + ";\n"; } else if (f.array_size == 0) { - output += " ArrayVector<" + type + "> " + f.name + ";\n"; + output += " pubsub::ArrayVector<" + type + ", Allocator> " + f.name + ";\n"; } else { @@ -808,24 +978,28 @@ std::string generate(const char* definition, const char* name) output += " void* operator new(size_t size)\n"; output += " {\n"; //output += " std::cout<< \"Overloading new operator with size: \" << size << std::endl;\n"; - output += " return malloc(size);\n"; + output += " return Allocator::allocator()->alloc(size, Allocator::allocator()->context);\n"; + //output += " return malloc(size);\n"; output += " }\n\n"; output += " void operator delete(void * p)\n"; output += " {\n"; //output += " std::cout<< \"Overloading delete operator \" << std::endl;\n"; - output += " free(p);\n"; + output += " Allocator::allocator()->free(p, Allocator::allocator()->context);\n"; + //output += " free(p);\n"; output += " }\n\n"; output += " static const ps_message_definition_t* GetDefinition()\n {\n"; output += " return &" + type_name + "_def;\n }\n\n"; output += " ps_msg_t Encode() const\n {\n"; - output += " return " + ns + "__" + raw_name + "_encode(&ps_default_allocator, this);\n }\n\n"; - output += " static " + raw_name + "* Decode(const void* data)\n {\n"; - output += " return (" + raw_name + "*)" + ns + "__" + raw_name + "_decode(data, &ps_default_allocator);\n }\n";// + ns + "__" + raw_name + "_encode(0, this);\n }\n"; + output += " return " + ns + "__" + raw_name + "_encode(this, Allocator::allocator());\n }\n\n"; + output += " static " + raw_name + "_* Decode(const void* data)\n {\n"; + output += " return (" + raw_name + "_*)" + ns + "__" + raw_name + "_decode(data, Allocator::allocator());\n }\n";// + ns + "__" + raw_name + "_encode(0, this);\n }\n"; output += "};\n"; + output += "typedef " + raw_name + "_<> " + raw_name + ";\n"; + output += "typedef std::shared_ptr<" + raw_name + "_<>> " + raw_name + "SharedPtr;\n"; + output += "typedef std::shared_ptr> " + raw_name + "SharedConstPtr;\n"; output += "#pragma pack(pop)\n"; - output += "typedef std::shared_ptr<" + raw_name + "> " + raw_name + "SharedPtr;\n"; output += "}\n"; output += "}\n"; diff --git a/tools/pubsub.cpp b/tools/pubsub.cpp index 26a4ef4..04d6db1 100644 --- a/tools/pubsub.cpp +++ b/tools/pubsub.cpp @@ -124,6 +124,13 @@ int topic_info(int num_args, char** args, ps_node_t* node) std::cout << "Type: " << info->second.type << "\n"; std::cout << "Latched: " << (((info->second.flags & PS_ADVERTISE_LATCHED) != 0) ? "True\n" : "False\n"); + int recommended_transport = ((info->second.flags & 0b111110) >> 1); + std::string transport = "UNKNOWN (" + std::to_string(recommended_transport) + ")"; + if (recommended_transport == 0) + transport = "UDP"; + else if (recommended_transport == 1) + transport = "TCP"; + std::cout << "Recommended Transport: " << transport << "\n"; std::cout << "Published by:\n"; for (auto pub : info->second.publishers) { @@ -211,6 +218,7 @@ int topic_echo(int num_args, char** args, ps_node_t* _node) parser.AddOption({ "n" }, "Number of messages to echo.", "0"); parser.AddOption({ "skip", "s" }, "Skip factor for the subscriber.", "0"); parser.AddFlag({ "tcp" }, "Prefer the TCP transport."); + parser.AddFlag({ "udp" }, "Prefer the UDP transport."); parser.AddFlag({ "no-arr" }, "Don't print out the contents of arrays in messages."); parser.AddOption({ "f", "field" }, "Print out just the value of a specific field."); @@ -223,6 +231,12 @@ int topic_echo(int num_args, char** args, ps_node_t* _node) return 0; } + if (parser.GetBool("tcp") && parser.GetBool("udp")) + { + printf("ERROR: Cannot provide both --tcp and --udp options.\n"); + exit(2); + } + static bool print_info = parser.GetBool("i"); double vn = parser.GetDouble("n"); if (vn <= 0) @@ -231,7 +245,6 @@ int topic_echo(int num_args, char** args, ps_node_t* _node) } static unsigned long long int n = vn; int skip = parser.GetDouble("s"); - bool tcp = parser.GetBool("tcp"); static bool no_arr = parser.GetBool("no-arr"); @@ -307,11 +320,12 @@ int topic_echo(int num_args, char** args, ps_node_t* _node) struct ps_subscriber_options options; ps_subscriber_options_init(&options); options.skip = skip; - options.queue_size = 0; options.allocator = 0; options.ignore_local = false; - options.preferred_transport = tcp ? 1 : 0; - options.cb = [](void* message, unsigned int size, void* data, const ps_msg_info_t* info) + options.preferred_transport = -1; + options.preferred_transport = parser.GetBool("tcp") ? 1 : options.preferred_transport; + options.preferred_transport = parser.GetBool("udp") ? 0 : options.preferred_transport; + options.cb_raw = [](void* message, unsigned int size, void* data, const ps_msg_info_t* info) { // get and deserialize the messages if (sub.received_message_def.fields == 0) @@ -342,7 +356,7 @@ int topic_echo(int num_args, char** args, ps_node_t* _node) } ps_deserialize_print(message, &sub.received_message_def, no_arr ? 10 : 0, field_name); printf("-------------\n"); - free(message); + free(message);// todo use allocator if (++count >= n) { // need to commit sudoku here.. @@ -478,7 +492,7 @@ int topic_pub(int num_args, char** args, ps_node_t* node) } // do initial publish - ps_msg_t cpy = ps_msg_cpy(&msg); + ps_msg_t cpy = ps_msg_cpy(&msg, 0); ps_pub_publish(&pub, &cpy); break; } @@ -502,7 +516,7 @@ int topic_pub(int num_args, char** args, ps_node_t* node) ps_node_spin(node); if (rate != 0 && remaining < pubsub::Duration(0.0)) { - ps_msg_t cpy = ps_msg_cpy(&msg); + ps_msg_t cpy = ps_msg_cpy(&msg, 0); ps_pub_publish(&pub, &cpy); next = next + pubsub::Duration(1.0/rate); } @@ -826,12 +840,19 @@ int main(int num_args_real, char** args) pubsub::ArgParser parser; parser.AddOption({ "w", "window" }, "Window size for averaging.", "100"); parser.AddFlag({ "tcp" }, "Prefer the TCP transport."); + parser.AddFlag({ "udp" }, "Prefer the UDP transport."); if (subverb == "hz") parser.SetUsage("Usage: info topic hz TOPIC\n\nDetermines the rate of publication for a given topic."); else parser.SetUsage("Usage: info topic bw TOPIC\n\nDetermines the single subscriber bandwidth for a given topic."); parser.Parse(args, num_args, 2); + if (parser.GetBool("tcp") && parser.GetBool("udp")) + { + printf("ERROR: Cannot provide both --tcp and --udp options.\n"); + exit(2); + } + // create a subscriber ps_sub_t sub; std::vector todo_msgs; @@ -867,9 +888,11 @@ int main(int num_args_real, char** args) ps_subscriber_options opts; ps_subscriber_options_init(&opts); - opts.cb = cb; + opts.cb_raw = cb; opts.cb_data = &message_times; - opts.preferred_transport = parser.GetBool("tcp") ? 1 : 0; + opts.preferred_transport = -1; + opts.preferred_transport = parser.GetBool("tcp") ? 1 : opts.preferred_transport; + opts.preferred_transport = parser.GetBool("udp") ? 0 : opts.preferred_transport; ps_node_create_subscriber_adv(&node, info->first.c_str(), 0, &sub, &opts); break; } diff --git a/tools/throughput_test.cpp b/tools/throughput_test.cpp index 0b8653f..ea10200 100644 --- a/tools/throughput_test.cpp +++ b/tools/throughput_test.cpp @@ -31,8 +31,7 @@ int main() // okay, since we are publishing with shared pointer we actually need to allocate the string properly auto shared = pubsub::msg::StringSharedPtr(new pubsub::msg::String); - shared->value = new char[strlen(msg.value) + 1]; - strcpy(shared->value, msg.value); + shared->value = msg.value; while (ps_okay()) {