diff --git a/README.md b/README.md index e1d71d0..a852fa9 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,25 @@ This is the code needed to calculate the witness by a circuit compiled with [circom](https://github.com/iden3/circom). +## Server +To compile server implementation: +- Compile Server +``` +g++ -O3 -std=c++17 -DSERVER_ENABLE -fopenmp -pthread calcwit.cpp main.cpp utils.cpp fr.cpp fr.o socket.cpp circuit-1960-32-256-64.cpp -o circuit-1960-32-256-64 -lgmp +``` +- Compile Client +``` +g++ -std=c++17 client.cpp socket.cpp -o client +``` +- Launch Server +``` +./circuit-1960-32-256-64 +``` +- Launch Client +``` +./client input-1960-32-256-64.json circuit-1960-32-256-64_w.wshm +``` + ## License diff --git a/c/client.cpp b/c/client.cpp new file mode 100644 index 0000000..700dab0 --- /dev/null +++ b/c/client.cpp @@ -0,0 +1,51 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include "socket.hpp" + +//main driver program +int main(int argc, char *argv[]) { + if (argc!=3) { + std::string cl = argv[0]; + std::string base_filename = cl.substr(cl.find_last_of("/\\") + 1); + std::cout << "Usage: " << base_filename << " > >\n"; + } else { + int hSocket, read_size, server_reply; + struct sockaddr_in server; + t_witness_msg message; + strcpy(message.inputFile, argv[1]); + strcpy(message.outputFile, argv[2]); + + //Create socket + hSocket = SocketCreate(); + if(hSocket == -1) { + printf("Could not create socket\n"); + return 1; + } + //Connect to remote server + if (SocketConnect(hSocket) < 0) { + perror("connect failed.\n"); + return 1; + } + //Send data to the server, and retry until file created + SocketSend(hSocket, (void *) &message, sizeof(t_witness_msg)); + + while (1) { + if (access(message.outputFile, F_OK) == 0) { + break; + } + sleep(1); + } + + close(hSocket); + shutdown(hSocket,0); + shutdown(hSocket,1); + shutdown(hSocket,2); + return 0; + } +} diff --git a/c/main.cpp b/c/main.cpp index 904091d..738fc85 100644 --- a/c/main.cpp +++ b/c/main.cpp @@ -18,6 +18,10 @@ using json = nlohmann::json; #include "circom.hpp" #include "utils.hpp" +#ifdef SERVER_ENABLE +#include "socket.hpp" +#endif + Circom_Circuit *circuit; @@ -25,6 +29,10 @@ Circom_Circuit *circuit; do { perror(msg); exit(EXIT_FAILURE); } while (0) #define SHMEM_WITNESS_KEY (123456) +#define FAST_LOG2(x) (sizeof(unsigned long)*8 - 1 - __builtin_clzl((unsigned long)(x))) +#define FAST_LOG2_UP(x) (((x) - (1 << FAST_LOG2(x))) ? FAST_LOG2(x) + 1 : FAST_LOG2(x)) + + // assumptions // 1) There is only one key assigned for shared memory. This means @@ -82,18 +90,19 @@ void writeOutShmem(Circom_CalcWit *ctx, std::string filename) { u64 idSection2length = n8*circuit->NVars; fwrite(&idSection2length, 8, 1, write_ptr); + u64 nElems = (1 << (FAST_LOG2_UP(nVars)+1)) + 8; // generate key key_t key = SHMEM_WITNESS_KEY; fwrite(&key, sizeof(key_t), 1, write_ptr); // Setup shared memory - if ((shmid = shmget(key, circuit->NVars * Fr_N64 * sizeof(u64), IPC_CREAT | 0666)) < 0) { + if ((shmid = shmget(key, nElems * Fr_N64 * sizeof(u64), IPC_CREAT | 0666)) < 0) { // preallocated shared memory segment is too small => Retrieve id by accesing old segment // Delete old segment and create new with corret size shmid = shmget(key, 4, IPC_CREAT | 0666); shmctl(shmid, IPC_RMID, NULL); - if ((shmid = shmget(key, circuit->NVars * Fr_N64 * sizeof(u64), IPC_CREAT | 0666)) < 0){ + if ((shmid = shmget(key, nElems * Fr_N64 * sizeof(u64), IPC_CREAT | 0666)) < 0){ status = -1; fwrite(&status, sizeof(status), 1, write_ptr); fclose(write_ptr); @@ -344,78 +353,140 @@ Circom_Circuit *loadCircuit(std::string const &datFileName) { return circuit; } + +void computeWitness(char *inputFile, char *outputFile) { + struct timeval begin, end; + long seconds, microseconds; + double elapsed; + + gettimeofday(&begin,0); + Circom_CalcWit *ctx = new Circom_CalcWit(circuit); + + std::string infilename = inputFile; + gettimeofday(&end,0); + seconds = end.tv_sec - begin.tv_sec; + microseconds = end.tv_usec - begin.tv_usec; + elapsed = seconds + microseconds*1e-6; + + printf("Up to loadJson %.20f\n", elapsed); + + if (hasEnding(infilename, std::string(".bin"))) { + loadBin(ctx, infilename); + } else if (hasEnding(infilename, std::string(".json"))) { + loadJson(ctx, infilename); + } else { + handle_error("Invalid input extension (.bin / .json)"); + } + + ctx->join(); + + // printf("Finished!\n"); + + std::string outfilename = outputFile; + + if (hasEnding(outfilename, std::string(".wtns"))) { + gettimeofday(&end,0); + seconds = end.tv_sec - begin.tv_sec; + microseconds = end.tv_usec - begin.tv_usec; + elapsed = seconds + microseconds*1e-6; + + printf("Up to WriteWtns %.20f\n", elapsed); + writeOutBin(ctx, outfilename); + } else if (hasEnding(outfilename, std::string(".json"))) { + writeOutJson(ctx, outfilename); + } else if (hasEnding(outfilename, std::string(".wshm"))) { + gettimeofday(&end,0); + seconds = end.tv_sec - begin.tv_sec; + microseconds = end.tv_usec - begin.tv_usec; + elapsed = seconds + microseconds*1e-6; + + printf("Up to WriteShmem %.20f\n", elapsed); + writeOutShmem(ctx, outfilename); + } else { + handle_error("Invalid output extension (.bin / .json)"); + } + + delete ctx; + gettimeofday(&end,0); + seconds = end.tv_sec - begin.tv_sec; + microseconds = end.tv_usec - begin.tv_usec; + elapsed = seconds + microseconds*1e-6; + + printf("Total %.20f\n", elapsed); +#ifndef SERVER_ENABLE + exit(EXIT_SUCCESS); +#endif +} + int main(int argc, char *argv[]) { + struct timeval begin, end; + long seconds, microseconds; + double elapsed; + + gettimeofday(&begin,0); +#ifndef SERVER_ENABLE if (argc!=3) { std::string cl = argv[0]; std::string base_filename = cl.substr(cl.find_last_of("/\\") + 1); std::cout << "Usage: " << base_filename << " > >\n"; } else { - - struct timeval begin, end; - long seconds, microseconds; - double elapsed; - - gettimeofday(&begin,0); - std::string datFileName = argv[0]; datFileName += ".dat"; circuit = loadCircuit(datFileName); - // open output - Circom_CalcWit *ctx = new Circom_CalcWit(circuit); - - std::string infilename = argv[1]; - gettimeofday(&end,0); - seconds = end.tv_sec - begin.tv_sec; - microseconds = end.tv_usec - begin.tv_usec; - elapsed = seconds + microseconds*1e-6; - - printf("Up to loadJson %.20f\n", elapsed); - - if (hasEnding(infilename, std::string(".bin"))) { - loadBin(ctx, infilename); - } else if (hasEnding(infilename, std::string(".json"))) { - loadJson(ctx, infilename); - } else { - handle_error("Invalid input extension (.bin / .json)"); - } - - ctx->join(); - - // printf("Finished!\n"); - - std::string outfilename = argv[2]; + gettimeofday(&end,0); + seconds = end.tv_sec - begin.tv_sec; + microseconds = end.tv_usec - begin.tv_usec; + elapsed = seconds + microseconds*1e-6; - if (hasEnding(outfilename, std::string(".wtns"))) { - gettimeofday(&end,0); - seconds = end.tv_sec - begin.tv_sec; - microseconds = end.tv_usec - begin.tv_usec; - elapsed = seconds + microseconds*1e-6; - - printf("Up to WriteWtns %.20f\n", elapsed); - writeOutBin(ctx, outfilename); - } else if (hasEnding(outfilename, std::string(".json"))) { - writeOutJson(ctx, outfilename); - } else if (hasEnding(outfilename, std::string(".wshm"))) { - gettimeofday(&end,0); - seconds = end.tv_sec - begin.tv_sec; - microseconds = end.tv_usec - begin.tv_usec; - elapsed = seconds + microseconds*1e-6; - - printf("Up to WriteShmem %.20f\n", elapsed); - writeOutShmem(ctx, outfilename); - } else { - handle_error("Invalid output extension (.bin / .json)"); - } + printf("Up to computeWitness %.20f\n", elapsed); + // open output + computeWitness(argv[1], argv[2]); - delete ctx; - gettimeofday(&end,0); - seconds = end.tv_sec - begin.tv_sec; - microseconds = end.tv_usec - begin.tv_usec; - elapsed = seconds + microseconds*1e-6; +#else + { + std::string datFileName = argv[0]; + datFileName += ".dat"; - printf("Total %.20f\n", elapsed); - exit(EXIT_SUCCESS); + int circuitInit=0; + t_witness_msg message; + + if (!ServerInit()) { + exit(1); + } + + while(1) { + int sock; + sock = ReceiveMsg((void *) &message, sizeof(t_witness_msg)); + if (!sock) { + continue; + } + std::cout << " Output file " << message.outputFile << "\n"; + std::cout << " Input file " << message.inputFile << "\n"; + + if (!circuitInit) { + std::cout << " Load Circuit " << datFileName << "\n"; + circuit = loadCircuit(datFileName); + circuitInit=1; + + gettimeofday(&end,0); + seconds = end.tv_sec - begin.tv_sec; + microseconds = end.tv_usec - begin.tv_usec; + elapsed = seconds + microseconds*1e-6; + + printf("Up to computeWitness %.20f\n", elapsed); + } + + if (circuitInit) { + std::cout << " Compute Witness " << message.outputFile << "\n"; + computeWitness(message.inputFile, message.outputFile); + } + + SocketClose(sock); + sleep(1); + } +#endif } } + diff --git a/c/socket.cpp b/c/socket.cpp new file mode 100644 index 0000000..047277a --- /dev/null +++ b/c/socket.cpp @@ -0,0 +1,123 @@ +#include +#include +#include +#include +#include +#include +#include "socket.hpp" + +static int SocketDesc; + +//Create a Socket for server communication +short SocketCreate(void) { + short hSocket; + hSocket = socket(AF_INET, SOCK_STREAM, 0); + return hSocket; +} + +int SocketConnect(int hSocket) { + int iRetval=-1; + int ServerPort = SERVER_PORT; + struct sockaddr_in remote= {0}; + remote.sin_addr.s_addr = inet_addr("127.0.0.1"); //Local Host + remote.sin_family = AF_INET; + remote.sin_port = htons(ServerPort); + iRetval = connect(hSocket,(struct sockaddr *)&remote,sizeof(struct sockaddr_in)); + return iRetval; +} + +// Send the data to the server and set the timeout of 20 seconds +int SocketSend(int hSocket,void* Rqst,short lenRqst) { + int shortRetval = -1; + struct timeval tv; + tv.tv_sec = 20; /* 20 Secs Timeout */ + tv.tv_usec = 0; + if(setsockopt(hSocket,SOL_SOCKET,SO_SNDTIMEO,(char *)&tv,sizeof(tv)) < 0) + { + printf("Time Out\n"); + return -1; + } + shortRetval = send(hSocket, Rqst, lenRqst, 0); + return shortRetval; +} + +int BindCreatedSocket(int hSocket) { + int iRetval=-1; + int ClientPort = SERVER_PORT; + struct sockaddr_in remote= {0}; + /* Internet address family */ + remote.sin_family = AF_INET; + /* Any incoming interface */ + remote.sin_addr.s_addr = htonl(INADDR_ANY); + remote.sin_port = htons(ClientPort); /* Local port */ + iRetval = bind(hSocket,(struct sockaddr *)&remote,sizeof(remote)); + return iRetval; +} + + +int ServerInit() { + //Create socket + SocketDesc = SocketCreate(); + if (SocketDesc == -1) { + printf("Could not create socket"); + return 0; + } + //Bind + if( BindCreatedSocket(SocketDesc) < 0) { + //print the error message + perror("bind failed."); + return 0; + } + //Listen + listen(SocketDesc, 3); + + return 1; +} + +//receive the data from the server +int SocketReceive(int hSocket,void* Rsp,short RvcSize) +{ + int shortRetval = -1; + struct timeval tv; + tv.tv_sec = 20; /* 20 Secs Timeout */ + tv.tv_usec = 0; + if(setsockopt(hSocket, SOL_SOCKET, SO_RCVTIMEO,(char *)&tv,sizeof(tv)) < 0) + { + printf("Time Out\n"); + return -1; + } + shortRetval = recv(hSocket, Rsp, RvcSize, 0); + return shortRetval; +} + +int ReceiveMsg(void *message, int len) { + int sock, clientLen; + struct sockaddr_in client; + printf("Waiting for incoming connections...\n"); + clientLen = sizeof(struct sockaddr_in); + //accept connection from an incoming client + sock = accept(SocketDesc,(struct sockaddr *)&client,(socklen_t*)&clientLen); + if (sock < 0) { + perror("accept failed"); + return 0; + } + printf("Connection accepted\n"); + //Receive a reply from the client + if( recv(sock, message, sizeof(t_witness_msg), 0) < 0) { + printf("recv failed"); + return 0; + } + return sock; +} + +void SocketClose(int socket) { + close(socket); +} + +int SendConfirmation(int sock) { + int reply = 1; + if( send(sock, (void *) &reply , sizeof(int), 0) < 0) { + return 0; + } + return 1; +} diff --git a/c/socket.hpp b/c/socket.hpp new file mode 100644 index 0000000..9591e7c --- /dev/null +++ b/c/socket.hpp @@ -0,0 +1,24 @@ + +#ifndef __SOCKET_H__ +#define __SOCKET_H__ + +#define FILE_LEN (1000) +#define SERVER_PORT (90190) + +typedef struct { + char inputFile[FILE_LEN]; + char outputFile[FILE_LEN]; + +}t_witness_msg; + +short SocketCreate(void); +int SocketConnect(int hSocket); +int SocketSend(int hSocket,void* Rqst,short lenRqst); +int SocketReceive(int hSocket,void * Rsp,short RvcSize); +int BindCreatedSocket(int hSocket); +int ServerInit(); +int ReceiveMsg(void *message, int len); +void SocketClose(int socket); +int SendConfirmation(int socket); + +#endif