diff --git a/.github/parse-coverage-report.js b/.github/parse-coverage-report.js new file mode 100644 index 0000000..9483cc2 --- /dev/null +++ b/.github/parse-coverage-report.js @@ -0,0 +1,55 @@ +const readline = require("readline"); + +const rl = readline.createInterface({ + input: process.stdin, + output: process.stdout, + terminal: false, +}); + +const summary = { fail: [], pass: [], skip: [] }; + +rl.on("line", (line) => { + const output = JSON.parse(line); + if ( + output.Action === "pass" || + output.Action === "skip" || + output.Action === "fail" + ) { + if (output.Test) { + summary[output.Action].push(output); + } + } +}); + +function totalTime(entries) { + return entries.reduce((total, l) => total + l.Elapsed, 0); +} + +rl.on("close", () => { + console.log("## πŸ“‹ Tests executed"); + console.log("| | Number of Tests | Total Time |"); + console.log("|--|--|--|"); + console.log( + "| βœ… Passed | %d | %fs |", + summary.pass.length, + totalTime(summary.pass) + ); + console.log( + "| ❌ Failed | %d | %fs |", + summary.fail.length, + totalTime(summary.fail) + ); + console.log( + "| πŸ”œ Skipped | %d | %fs |", + summary.skip.length, + totalTime(summary.skip) + ); + + if (summary.fail.length > 0) { + console.log("\n## Failures\n"); + } + + summary.fail.forEach((test) => { + console.log("* %s (%s) %fs", test.Test, test.Package, test.Elapsed); + }); +}); \ No newline at end of file diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 0000000..4eb7bd2 --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,75 @@ +name: Build and Test + +on: + push: + branches: + - main + pull_request: + +jobs: + job_go_checks: + runs-on: ubuntu-latest + defaults: + run: + shell: bash + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 1 + - name: Set up Go environment + uses: actions/setup-go@v5 + with: + go-version: "1.24" + - name: Tidy go module + run: | + go mod tidy + if [[ $(git status --porcelain) ]]; then + git diff + echo + echo "go mod tidy made these changes, please run 'go mod tidy' and include those changes in a commit" + exit 1 + fi + - name: Run gofumpt + run: diff -u <(echo -n) <(go run mvdan.cc/gofumpt@@latest -d .) + - name: Run go vet + run: go vet ./... + + job_go_test: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 2 + - name: Set up Go environment + uses: actions/setup-go@v5 + with: + go-version: "1.24" + - name: Run Go test -race + run: go test ./... -v -race -timeout=1h + - name: Rerun Go test to generate coverage report + run: | + go test -v -timeout 15m -coverprofile=./cover.out -json ./... > tests.log + - name: Convert report to html + run: go tool cover -html=cover.out -o cover.html + - name: Print coverage report + run: | + set -o pipefail && cat tests.log | node .github/parse-coverage-report.js >> $GITHUB_STEP_SUMMARY + echo $GITHUB_STEP_SUMMARY + - name: Print coverage report + run: | + go tool cover -func=cover.out > ./cover.txt + echo "
πŸ“ Tests coverage" >> $GITHUB_STEP_SUMMARY + echo -e "\n\`\`\`" >> $GITHUB_STEP_SUMMARY + cat ./cover.txt >> $GITHUB_STEP_SUMMARY + echo -e "\`\`\`\n
" >> $GITHUB_STEP_SUMMARY + - name: Store coverage report + uses: actions/upload-artifact@v4 + with: + name: report + path: | + tests.log + cover.txt + cover.out + cover.html diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..0ad25db --- /dev/null +++ b/LICENSE @@ -0,0 +1,661 @@ + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. + + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published + by the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..9e5f3f8 --- /dev/null +++ b/Makefile @@ -0,0 +1,29 @@ +.PHONY: run demo clean + +demo: clean-demo + @echo "Building demo image..." + docker build -f docker/Dockerfile.demo -t demo-simpleauthlink . + @echo "Running demo container..." + docker run --name demo-simpleauthlink --env-file demo.env -p ${PORT}:80 -d demo-simpleauthlink + +clean-demo: + @echo "Cleaning up previous containers and images..." + @docker rm -f demo-simpleauthlink 2>/dev/null || true + @echo "Containers cleaned up" + @docker rmi -f demo-simpleauthlink 2>/dev/null || true + @echo "Images cleaned up" + @echo "Cleaning up done" + +api: clean-api + @echo "Building API image..." + docker build -f docker/Dockerfile.prod -t simpleauthlink . + @echo "Running API container..." + docker run --name simpleauthlink --env-file .env -p ${PORT}:80 simpleauthlink + +clean-api: + @echo "Cleaning up previous containers and images..." + @docker rm -f simpleauthlink 2>/dev/null || true + @echo "Containers cleaned up" + @docker rmi -f simpleauthlink 2>/dev/null || true + @echo "Images cleaned up" + @echo "Cleaning up done" \ No newline at end of file diff --git a/README.md b/README.md index d1e9f50..2496487 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,112 @@ -# SimpleAuthLink API +[![Last release](https://img.shields.io/github/v/release/simpleauthlink/authapi?color=purple)](https://github.com/simpleauthlink/authapi/releases/latest) +[![GoDoc](https://godoc.org/github.com/simpleauthlink/authapi?status.svg)](https://godoc.org/github.com/simpleauthlink/authapi) +[![Go Report Card](https://goreportcard.com/badge/github.com/simpleauthlink/authapi)](https://goreportcard.com/report/github.com/simpleauthlink/authapi) +[![Build and Test](https://github.com/simpleauthlink/authapi/actions/workflows/main.yml/badge.svg?branch=main)](https://github.com/simpleauthlink/authapi/actions/workflows/main.yml) +[![license](https://img.shields.io/github/license/simpleauthlink/authapi)](LICENSE) -WIP document, read more in the project website: [https://simpleauth.link](https://simpleauth.link). \ No newline at end of file + + + +# SimpleAuth.link API + +> Passwordless authentication for your users using just an email address. + +This repository contains the source code of the SimpleAuth.link API Service. + +Read full [documentation here](https://docs.simpleauth.link). + +--- + +## Technical Details πŸ’» + +### Token Generation Process πŸ”‘ + +* By leveraging the Ed25519 signature algorithm, the service deterministically generates a private key using your App ID and secret. +* This ensures that each token is cryptographically secure and uniquely tied to your application, eliminating the need to store sensitive keys. + +### Stateless Architecture πŸ•ŠοΈ + +* SimpleAuth.link works without a traditional database. It does not store any user data, including email addresses, on its servers. +* Instead, the data generated is self-contained, requiring no further information or state to be used. This stateless design increases security and reduces the risk of data breaches. + +## Development πŸ§‘β€πŸ’» + +### Prerequisites πŸ“ + - Go (version 1.24 or later recommended) + - Docker (optional, for containerized deployment) + +### Clone the Repository πŸ“₯ +```sh +git clone --branch v2 https://github.com/SimpleAuthLink/authapi.git +cd authapi +``` + +### Code Structure πŸͺœ +* Modularity: The repository is structured into packages to simplify testing and future expansion. + - `api/`: Contains the core API endpoint definitions and routing logic. + - `cmd/`: Entry points and command-line utilities for running the API server. + - `docker/`: Dockerfiles and configuration for containerized builds. + - `internal/`: Internal packages and libraries used across the application. + - `notification/`: Code handling user notifications. + - `token/`: Modules for token creation and management. + - `.github/`: GitHub-specific workflows and configurations (CI/CD, issue templates, etc.). + + This layout supports modular development and clear separation of concerns across different parts of the API service. + +* Testing: Use the standard Go testing framework to run tests. You can run tests with: + ```sh + go test ./... + ``` + +### Run with go 🦫 +For development purposes, you can run the API server directly with Go. +```sh +go run ./cmd/authapi -h +``` +```sh +Usage of authapi: + -email-addr string + email account address + -email-host string + email server host + -email-pass string + email account password + -email-port int + email server port (default 587) + -host string + service host (default "0.0.0.0") + -port int + service host (default 8080) + -secret string + secret used to generate the tokens (default "simpleauthlink-secret") +``` + +### Run with docker 🐳 + +1. **Prepare the Environment File** + + Copy the `example.env` file to `.env` and edit the file to fill in your parameters: + ```bash + HOST="localhost" + PORT=8080 + EMAIL_ADDR="test@test.com" + EMAIL_PASS="smtp_server_password" + EMAIL_HOST="smtp.example.com" + EMAIL_PORT=587 + SECRET="my_backend_secret" + ``` + +2. **Build the Docker Image** + + Run the following command in the root of your project to build the image: + ```bash + docker build -f docker/Dockerfile.prod -t simpleauthlink . + ``` + +3. **Run the Docker Container** + + Once the image is built, start a container using the environment file: + ```bash + docker run --name simpleauthlink --env-file .env -p 8080:80 simpleauthlink + ``` + This command maps the container’s port 80 to port 8080 on your host. \ No newline at end of file diff --git a/api/apps.go b/api/apps.go deleted file mode 100644 index cd78c48..0000000 --- a/api/apps.go +++ /dev/null @@ -1,177 +0,0 @@ -package api - -import ( - "encoding/hex" - "fmt" - - "github.com/simpleauthlink/authapi/db" - "github.com/simpleauthlink/authapi/helpers" -) - -// authApp method creates a new app based on the provided name, email, redirectURL -// and duration. It returns the app id and the app secret. If the name, email or -// redirectURL are empty, it returns an error. If the duration is less than the -// minimum duration, it returns an error. If something fails during the process, -// it returns an error. The app id and the app secret are generated based on the -// email using the generateApp function. The app is stored in the database using -// the app id as the key. The secret is stored in the database using the hashed -// secret as the key. The hashed secret is required to be compared with the -// secret provided by the user in the requests. -func (s *Service) authApp(name, email, redirectURL string, duration uint64) (string, string, error) { - // check if the name, email, and redirectURL are not empty - if len(name) == 0 || len(email) == 0 || len(redirectURL) == 0 { - return "", "", fmt.Errorf("name, email, and redirectURL are required") - } - // check if the duration is valid - if duration < helpers.MinTokenDuration { - return "", "", fmt.Errorf("duration must be at least %d seconds", helpers.MinTokenDuration) - } - // compose the app struct for the database - appData := &db.App{ - Name: name, - AdminEmail: email, - SessionDuration: duration, - RedirectURL: redirectURL, - UsersQuota: helpers.DefaultUsersQuota, - } - // generate app based on email - appId, secret, hSecret, err := generateApp(appData.AdminEmail) - if err != nil { - return "", "", err - } - // store app in the database - if err := s.db.SetApp(appId, appData); err != nil { - return "", "", err - } - // store secret in the database - if err := s.db.SetSecret(hSecret, appId); err != nil { - return "", "", err - } - return appId, secret, nil -} - -// appMetadata method retrieves the app data based on the app id. If the app id is -// empty, it returns an error. If something fails during the process, it returns -// an error. The app data includes the name, the email of the admin, the redirect -// URL, the duration, the users quota, and the current users. The current users -// are retrieved from the database using the app id to count the number of tokens -// for the app. -func (s *Service) appMetadata(appId string) (AppData, error) { - dbApp, err := s.db.AppById(appId) - if err != nil { - return AppData{}, err - } - app := AppData{ - Name: dbApp.Name, - Email: dbApp.AdminEmail, - RedirectURL: dbApp.RedirectURL, - Duration: dbApp.SessionDuration, - UsersQuota: dbApp.UsersQuota, - } - // get the number of current tokens for the app, if it fails, it returns 0 - app.CurrentUsers, _ = s.db.CountTokens(appId) - return app, nil -} - -// updateAppMetadata method updates the app metadata based on the app id, name, -// redirectURL, and duration. If the app id is empty, it returns an error. If -// the duration is non zero an less than the minimum duration, it returns an -// error. If something fails during the process, it returns an error. -func (s *Service) updateAppMetadata(appId, name, redirectURL string, duration uint64) error { - // check if the app id is not empty - if len(appId) == 0 { - return fmt.Errorf("app id is required") - } - // check if the duration is valid - if duration != 0 && duration < helpers.MinTokenDuration { - return fmt.Errorf("duration must be at least %d seconds", helpers.MinTokenDuration) - } - // get app from the database - app, err := s.db.AppById(appId) - if err != nil { - return err - } - // update app metadata - if name != "" { - app.Name = name - } - if redirectURL != "" { - app.RedirectURL = redirectURL - } - if duration != 0 { - app.SessionDuration = duration - } - // store app in the database - return s.db.SetApp(appId, app) -} - -// removeApp method removes an app based on the app id. If the app id is empty, -// it returns an error. If something fails during the process, it returns an -// error. It also removes all the tokens for the app from the database using -// the app id as the prefix to find them. -func (s *Service) removeApp(appId string) error { - // check if the app id is not empty - if len(appId) == 0 { - return fmt.Errorf("app id is required") - } - // remove all the tokens for the app from the database, using the app id as - // the prefix - if err := s.db.DeleteTokensByPrefix(appId); err != nil { - return err - } - // remove app from the database - return s.db.DeleteApp(appId) -} - -func (s *Service) validSecret(appId, rawSecret string) bool { - secret, err := helpers.Hash(rawSecret, helpers.SecretSize) - if err != nil { - return false - } - valid, err := s.db.ValidSecret(secret, appId) - if err != nil { - return false - } - return valid -} - -// generateApp function generates an app based on the email. It returns the app -// id, the app secret and the hashed secret. If the email is empty or something -// fails during the process, it returns an error. The app id is generated -// hashing the email with a length of 4 bytes. The app secret is generated -// using the appSecret function. -func generateApp(email string) (string, string, string, error) { - if len(email) == 0 { - return "", "", "", fmt.Errorf("email is required") - } - // hash email - hEmail, err := helpers.Hash(email, helpers.EmailHashSize) - if err != nil { - return "", "", "", err - } - bAppNonce := helpers.RandBytes(helpers.AppNonceSize) - hAppNonce := hex.EncodeToString(bAppNonce) - appId := hEmail + hAppNonce - // generate secret - secret, hSecret, err := appSecret() - if err != nil { - return "", "", "", err - } - return appId, secret, hSecret, nil -} - -// appSecret function generates an new app secret. It returns the secret, the -// hashed secret and an error if something fails during the process. The secret -// is a random sequence of 16 bytes encoded as a hexadecimal string. The hashed -// secret is required to store the secret in the database without exposing it. -func appSecret() (string, string, error) { - // generate secret - bSecret := helpers.RandBytes(helpers.SecretSize) - secret := hex.EncodeToString(bSecret) - // hash secret - hSecret, err := helpers.Hash(secret, helpers.SecretSize) - if err != nil { - return "", "", err - } - return secret, hSecret, nil -} diff --git a/api/errors.go b/api/errors.go new file mode 100644 index 0000000..6613cf8 --- /dev/null +++ b/api/errors.go @@ -0,0 +1,27 @@ +package api + +import ( + "net/http" + + "github.com/simpleauthlink/authapi/api/io" +) + +var ( + // Decode data errors + DecodeAppIDRequestErr = io.NewAPIError(1001, http.StatusBadRequest).With("could not decode app id request") + DecodeTokenRequestErr = io.NewAPIError(1002, http.StatusBadRequest).With("could not decode token request") + DecodeTokenStatusRequestErr = io.NewAPIError(1003, http.StatusBadRequest).With("could not decode token status request") + // Encode data errors + EncodeAppIDResponseErr = io.NewAPIError(1010, http.StatusInternalServerError).With("could not encode app id response") + EncodeTokenStatusResponseErr = io.NewAPIError(1011, http.StatusInternalServerError).With("could not encode token status response") + // Bad request errors + InvalidAppHeadersErr = io.NewAPIError(1020, http.StatusBadRequest).With("invalid app headers") + InvalidAppIDErr = io.NewAPIError(1021, http.StatusBadRequest).With("invalid app id") + InvalidAppSecretErr = io.NewAPIError(1022, http.StatusBadRequest).With("invalid app secret") + InvalidDemoEmailInboxErr = io.NewAPIError(1023, http.StatusBadRequest).With("invalid demo email inbox") + // Internal errors + GenerateTokenErr = io.NewAPIError(1030, http.StatusInternalServerError).With("could not generate token") + GenerateEmailErr = io.NewAPIError(1031, http.StatusInternalServerError).With("could not generate email") + SendEmailErr = io.NewAPIError(1032, http.StatusInternalServerError).With("could not send email") + InternalErr = io.NewAPIError(1033, http.StatusInternalServerError).With("internal server error") +) diff --git a/api/handlers.go b/api/handlers.go index dc2c72f..95f0ae4 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -1,321 +1,168 @@ package api import ( - "encoding/json" "fmt" - "io" - "log" "net/http" - "github.com/simpleauthlink/authapi/db" - "github.com/simpleauthlink/authapi/email" - "github.com/simpleauthlink/authapi/helpers" + "github.com/simpleauthlink/authapi/api/io" + "github.com/simpleauthlink/authapi/notification" + "github.com/simpleauthlink/authapi/notification/templates/login" + "github.com/simpleauthlink/authapi/token" ) -// userTokenHandler method generates a token for the user and sends it via email -// to the user's email address. The token is generated based on the app id -// and the user's email address. The token is stored in the database with an -// expiration time. It gets the app secret from the helpers.AppSecretHeader -// header and the user's email address from the request body. If it success it -// sends an "Ok" response. If something goes wrong, it sends an internal server -// error response. If the app secret is missing or the request body is invalid, -// it sends a bad request response. -func (s *Service) userTokenHandler(w http.ResponseWriter, r *http.Request) { - // read the app token header - appSecret := r.Header.Get(helpers.AppSecretHeader) - if appSecret == "" { - http.Error(w, "missing app token", http.StatusBadRequest) - return - } - // read body - defer r.Body.Close() - body, err := io.ReadAll(r.Body) - if err != nil { - log.Println("ERR: error reading request body:", err) - http.Error(w, "error reading request body", http.StatusInternalServerError) - return - } - // parse request - req := &TokenRequest{} - if err := json.Unmarshal(body, req); err != nil { - log.Println("ERR: error parsing request body:", err) - http.Error(w, "error parsing request body", http.StatusBadRequest) - return - } - // check if the email is allowed - if !s.emailQueue.Allowed(req.Email) { - http.Error(w, "disallowed domain", http.StatusBadRequest) - return - } - // generate token - magicLink, token, appName, err := s.magicLink(appSecret, req.Email, req.RedirectURL, req.Duration) - if err != nil { - log.Println("ERR: error generating token:", err) - http.Error(w, "error generating token", http.StatusInternalServerError) - return - } - // compose and push the email to the queue to be sent, if it fails, delete - // the token from the database, log the error and send an error response - emailData := email.NewUserEmailData(appName, req.Email, magicLink, token) - emailBody, err := email.ParseTemplate(s.cfg.TokenEmailTemplate, emailData) - if err != nil { - log.Println("ERR: error parsing email template:", err) - http.Error(w, "error parsing email template", http.StatusInternalServerError) - return - } - if err := s.emailQueue.Push(&email.Email{ - To: req.Email, - Subject: fmt.Sprintf(userTokenSubject, appName), - Body: emailBody, - }); err != nil { - log.Println("ERR: error sending email:", err) - if err := s.db.DeleteToken(db.Token(token)); err != nil { - log.Println("ERR: error deleting token:", err) - } - http.Error(w, "error sending email", http.StatusInternalServerError) - return - } - // send response - if _, err := w.Write([]byte("Ok")); err != nil { - log.Println("ERR: error sending response:", err) - http.Error(w, "error sending response", http.StatusInternalServerError) - return - } +// generateAppIDHandler handles the request to generate an app id it decodes +// the app data from the request body and returns the app id in the response +// body. Every app data information is required to generate the app id. The +// app id is a self-contained representation of the app that can be used to +// generate tokens. It is created by encoding the app as a base64-encoded +// byte slice resulting in concatenating the app name, redirect uri, and +// session duration. +func (s *Service) generateAppIDHandler(w http.ResponseWriter, r *http.Request) { + // decode the app data from the request body + req := new(io.Request[AppIDRequest]) + if err := req.Read(r); err != nil { + DecodeAppIDRequestErr.WithErr(err).Write(w) + return + } + // create the app from the data and check if it is valid + app, appSecret := req.Data.parseApp() + secret := new(token.Secret).SetParts([]byte(s.cfg.Secret), []byte(appSecret)) + app.SetSecret(secret) + if !app.Valid(secret.Hash()) { + InvalidAppIDErr.Write(w) + return + } + // return the app id + io.ResponseWith(&AppIDResponse{app.ID(secret).String()}).WriteJSON(w) } -// validateUserTokenHandler method validates the user token. It gets the token -// from the helpers.TokenQueryParam query string and checks if it is valid. If -// the token is valid, it sends a response with the "Ok" message. If the token -// is invalid, it sends an unauthorized response. If the token is missing, it -// sends a bad request response. -func (s *Service) validateUserTokenHandler(w http.ResponseWriter, r *http.Request) { - // read the app token header - appSecret := r.Header.Get(helpers.AppSecretHeader) - if appSecret == "" { - http.Error(w, "missing app token", http.StatusBadRequest) - return - } - // get the token from the query - token := r.URL.Query().Get(helpers.TokenQueryParam) - if token == "" { - http.Error(w, "missing token", http.StatusBadRequest) - return - } - // validate the token - if !s.validUserToken(token, appSecret) { - http.Error(w, "invalid token", http.StatusUnauthorized) - return - } - if _, err := w.Write([]byte("Ok")); err != nil { - log.Println("ERR: error sending response:", err) - http.Error(w, "error sending response", http.StatusInternalServerError) - return - } -} - -// appTokenHandler method generates creates an app in the service, it generates -// an app id and a secret for the app. It sends the app id and the secret via -// email to the app's email address. It gets the app name, email, callback, and -// duration from the request body. If it success it sends an "Ok" response. If -// something goes wrong, it sends an internal server error response. If the -// request body is invalid, it sends a bad request response. -func (s *Service) appTokenHandler(w http.ResponseWriter, r *http.Request) { - // read body - defer r.Body.Close() - body, err := io.ReadAll(r.Body) - if err != nil { - log.Println("ERR: error reading request body:", err) - http.Error(w, "error reading request body", http.StatusInternalServerError) - return - } - app := &AppData{} - if err := json.Unmarshal(body, app); err != nil { - log.Println("ERR: error parsing request body:", err) - http.Error(w, "error parsing request body", http.StatusBadRequest) - return - } - // check if the email is allowed - if !s.emailQueue.Allowed(app.Email) { - http.Error(w, "disallowed domain", http.StatusBadRequest) - return - } - // generate token - appId, secret, err := s.authApp(app.Name, app.Email, app.RedirectURL, app.Duration) +func (s *Service) requestTokenHandler(w http.ResponseWriter, r *http.Request) { + // get the app id from the request header + strAppID, strAppSecret, err := appConfigFromRequest(r) if err != nil { - log.Println("ERR: error generating token:", err) - http.Error(w, "error generating token", http.StatusInternalServerError) + InvalidAppHeadersErr.WithErr(err).Write(w) return } - emailData := email.NewAppEmailData(appId, app.Name, app.RedirectURL, secret, app.Email) - emailBody, err := email.ParseTemplate(s.cfg.AppEmailTemplate, emailData) - if err != nil { - log.Println("ERR: error parsing email template:", err) - http.Error(w, "error parsing email template", http.StatusInternalServerError) + // decode the app id get the app from it + appID := new(token.AppID).SetString(strAppID) + app := new(token.App).SetID(appID) + // compose the app secret with both parts + secret := new(token.Secret).SetParts([]byte(s.cfg.Secret), []byte(strAppSecret)) + if !secret.Valid() { + InvalidAppSecretErr.Write(w) return } - // compose and push the email to the queue to be sent if it fails, delete - // the app from the database, log the error and send an error response - if err := s.emailQueue.Push(&email.Email{ - To: app.Email, - Subject: fmt.Sprintf(appTokenSubject, app.Name), - Body: emailBody, - }); err != nil { - log.Println("ERR: error sending email:", err) - if err := s.removeApp(appId); err != nil { - log.Println("ERR: error deleting app:", err) - } - http.Error(w, "error sending email", http.StatusInternalServerError) + // check if the app id is valid (it should be a valid app) + if !app.Valid(secret.Hash()) { + InvalidAppIDErr.Write(w) return } - // send response - if _, err := w.Write([]byte("Ok")); err != nil { - log.Println("ERR: error sending response:", err) - http.Error(w, "error sending response", http.StatusInternalServerError) + // decode the token request from the request body + req := new(io.Request[TokenRequest]) + if err := req.Read(r); err != nil { + DecodeTokenRequestErr.WithErr(err).Write(w) return } -} - -// appHandler method gets the app metadata from the service. It gets the app id -// from the token provided in the URL query. If the token is missing, it sends -// a bad request response. If the token is invalid or is not an admin token, it -// sends an unauthorized response. If the app is not found, it sends a not found -// response. If it success it sends the app metadata. If something goes wrong, -// it sends an internal server error response. -func (s *Service) appHandler(w http.ResponseWriter, r *http.Request) { - // read the app token header - appSecret := r.Header.Get(helpers.AppSecretHeader) - if appSecret == "" { - http.Error(w, "missing app token", http.StatusBadRequest) + // generate user token + token := appID.GenerateToken(*secret, req.Data.Email) + if token == nil { + GenerateTokenErr.With(req.Data.Email).Write(w) return } - // get the token from the query - token := r.URL.Query().Get(helpers.TokenQueryParam) - if token == "" { - http.Error(w, "missing token", http.StatusBadRequest) - return + // compose the email with the token + loginData := login.Data{ + AppName: app.Name, + Email: req.Data.Email, + Token: token.String(), + Link: app.RedirectURI + token.String(), } - // validate the token and get the app id - appId, valid := s.validAdminToken(token, appSecret) - if !valid { - http.Error(w, "invalid token", http.StatusUnauthorized) - return - } - // get the app from the database - app, err := s.appMetadata(appId) - if err != nil { - if err == db.ErrAppNotFound { - http.Error(w, "app not found", http.StatusNotFound) - return - } - log.Println("ERR: error getting app:", err) - http.Error(w, "error getting app", http.StatusInternalServerError) - return - } - // encode the app metadata - res, err := json.Marshal(&app) + loginEmail, err := login.Template.Compose(notification.NotificationParams{ + To: req.Data.Email, + Subject: loginData.Subject(), + }, loginData) if err != nil { - log.Println("ERR: error marshaling app:", err) - http.Error(w, "error marshaling app", http.StatusInternalServerError) + GenerateEmailErr.WithErr(err).Write(w) return } - // send response - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if _, err := w.Write(res); err != nil { - log.Println("ERR: error sending response:", err) - http.Error(w, "error sending response", http.StatusInternalServerError) + // push the email to the notification queue + if err := s.nq.Push(loginEmail); err != nil { + SendEmailErr.WithErr(err).Write(w) return } + io.OkResponse().WriteJSON(w) } -// updateAppHandler method updates an app in the service. It gets the app id -// from the URL path and the app name, callback, and duration from the request -// body. If the app id is missing, it sends a bad request response. If the app -// is not found, it sends a not found response. If it success it sends an Ok -// response. If something goes wrong, it sends an internal server error -// response. -func (s *Service) updateAppHandler(w http.ResponseWriter, r *http.Request) { - // read the app token header - appSecret := r.Header.Get(helpers.AppSecretHeader) - if appSecret == "" { - http.Error(w, "missing app token", http.StatusBadRequest) - return - } - // get the token from the query - token := r.URL.Query().Get(helpers.TokenQueryParam) - if token == "" { - http.Error(w, "missing token", http.StatusBadRequest) - return - } - // validate the token and get the app id - appId, valid := s.validAdminToken(token, appSecret) - if !valid { - http.Error(w, "invalid token", http.StatusUnauthorized) - return - } - // read body - defer r.Body.Close() - body, err := io.ReadAll(r.Body) +func (s *Service) verifyTokenHandler(w http.ResponseWriter, r *http.Request) { + // get the app id from the request header + strAppID, strAppSecret, err := appConfigFromRequest(r) if err != nil { - log.Println("ERR: error reading request body:", err) - http.Error(w, "error reading request body", http.StatusInternalServerError) + InvalidAppHeadersErr.WithErr(err).Write(w) return } - // decode the app from the request - app := &AppData{} - if err := json.Unmarshal(body, app); err != nil { - log.Println("ERR: error parsing request body:", err) - http.Error(w, "error parsing request body", http.StatusBadRequest) + // decode the app id get the app from it + appID := new(token.AppID).SetString(strAppID) + app := new(token.App).SetID(appID) + // compose the app secret with both parts + secret := new(token.Secret).SetParts([]byte(s.cfg.Secret), []byte(strAppSecret)) + if !secret.Valid() { + InvalidAppSecretErr.Write(w) return } - // update the app in the database - if err := s.updateAppMetadata(appId, app.Name, app.RedirectURL, app.Duration); err != nil { - log.Println("ERR: error updating app:", err) - http.Error(w, "error updating app", http.StatusInternalServerError) + // check if the app id is valid (it should be a valid app) + if !app.Valid(secret.Hash()) { + InvalidAppIDErr.Write(w) return } - // send response - if _, err := w.Write([]byte("Ok")); err != nil { - log.Println("ERR: error sending response:", err) - http.Error(w, "error sending response", http.StatusInternalServerError) + // decode the token status request from the request body + req := new(io.Request[TokenStatusRequest]) + if err := req.Read(r); err != nil { + DecodeTokenStatusRequestErr.WithErr(err).Write(w) return } + // check if the token is valid + tkn := new(token.Token).SetString(req.Data.Token) + exp := tkn.Expiration().Time() + io.ResponseWith(&TokenStatusResponse{ + Valid: appID.VerifyToken(*tkn, *secret, req.Data.Email), + Expiration: exp, + }).WriteJSON(w) } -// delAppHandler method deletes an app from the service. It gets the app id from -// the token provided in the URL query. If the token is missing, it sends a bad -// request response. If the token is invalid or is not an admin token, it sends -// an unauthorized response. If it success it sends an Ok response. If something -// goes wrong, it sends an internal server error response. -func (s *Service) delAppHandler(w http.ResponseWriter, r *http.Request) { - // read the app token header - appSecret := r.Header.Get(helpers.AppSecretHeader) - if appSecret == "" { - http.Error(w, "missing app token", http.StatusBadRequest) - return - } - // get the token from the query - token := r.URL.Query().Get(helpers.TokenQueryParam) - if token == "" { - http.Error(w, "missing token", http.StatusBadRequest) - return - } - // validate the token and get the app id - appId, valid := s.validAdminToken(token, appSecret) - if !valid { - http.Error(w, "invalid token", http.StatusUnauthorized) - return - } - // remove the app from the service - if err := s.removeApp(appId); err != nil { - log.Println("ERR: error deleting app:", err) - http.Error(w, "error deleting app", http.StatusInternalServerError) - return - } - // send response - if _, err := w.Write([]byte("Ok")); err != nil { - log.Println("ERR: error sending response:", err) - http.Error(w, "error sending response", http.StatusInternalServerError) - return +func (s *Service) healthCheckHandler(w http.ResponseWriter, r *http.Request) { + io.OkResponse().Write(w) +} + +func (s *Service) demoInboxHandler(w http.ResponseWriter, r *http.Request) { + // get the email from get parameters + email := r.URL.Query().Get("email") + if email == "" { + InvalidDemoEmailInboxErr.Write(w) + return + } + // set http headers required for SSE + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + // create a channel for client disconnection + clientGone := r.Context().Done() + // create a response controller + rc := http.NewResponseController(w) + for { + select { + case <-s.ctx.Done(): + case <-clientGone: + return + case msg := <-s.demoMailInbox: + // find the token in the email + if testToken := login.FindToken(email, msg); testToken != nil { + // send an event to the client with the token in the "data" field + if _, err := fmt.Fprintf(w, "data: %s\n\n", testToken); err != nil { + return + } + if err := rc.Flush(); err != nil { + return + } + } + } } } diff --git a/api/handlers_test.go b/api/handlers_test.go new file mode 100644 index 0000000..bc55131 --- /dev/null +++ b/api/handlers_test.go @@ -0,0 +1,309 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "testing" + "time" + + apiio "github.com/simpleauthlink/authapi/api/io" + "github.com/simpleauthlink/authapi/notification/email" + "github.com/simpleauthlink/authapi/notification/templates/login" + "github.com/simpleauthlink/authapi/token" +) + +type testCaseAPIHandler[ReqType, ResType any] struct { + name string + method string + endpoint string + header http.Header + request *ReqType + response *ResType + err *apiio.APIError +} + +func (testCase testCaseAPIHandler[Rq, Rs]) url() string { + return fmt.Sprintf("%s%s", testServerApiURL, testCase.endpoint) +} + +func (testCase testCaseAPIHandler[Rq, Rs]) Run(t *testing.T) { + t.Run(testCase.name, func(t *testing.T) { + var reqBuffer io.Reader + if testCase.request != nil { + rawBody, err := json.Marshal(testCase.request) + if err != nil { + t.Fatalf("could not marshal request: %v", err) + } + reqBuffer = bytes.NewReader(rawBody) + } + req, err := http.NewRequest(testCase.method, testCase.url(), reqBuffer) + if err != nil { + t.Fatalf("could not create request: %v", err) + } + req.Header = testCase.header + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("could not send request: %v", err) + } + defer resp.Body.Close() + switch { + case testCase.err != nil: + if resp.StatusCode != testCase.err.StatusCode { + t.Fatalf("expected status code: %d, got: %d", testCase.err.StatusCode, resp.StatusCode) + } + err := new(apiio.APIError) + if err := json.NewDecoder(resp.Body).Decode(err); err != nil { + t.Fatalf("could not decode error response: %v", err) + } + if err.Code != testCase.err.Code { + t.Fatalf("expected error code: %d, got: %d", testCase.err.Code, err.Code) + } + return + case testCase.response != nil: + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status code: %d, got: %d", http.StatusOK, resp.StatusCode) + } + expected, err := json.Marshal(testCase.response) + if err != nil { + t.Fatalf("could not marshal response: %v", err) + } + res, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("could not read response: %v", err) + } + if !bytes.Equal(bytes.TrimSpace(expected), bytes.TrimSpace(res)) { + t.Fatalf("expected response: %s, got: %s", expected, res) + } + default: + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status code: %d, got: %d", http.StatusOK, resp.StatusCode) + } + } + }) +} + +func TestGenerateAppIDHandler(t *testing.T) { + testApp := &token.App{ + Name: testAppName, + RedirectURI: testAppRedirectURL, + SessionDuration: testAppSessionDuration, + } + secret := new(token.Secret).SetParts([]byte(testServerSecret), []byte(testAppSecret)) + testApp.SetSecret(secret) + testCaseAPIHandler[AppIDRequest, AppIDResponse]{ + name: "valid request", + method: http.MethodPost, + endpoint: AppsPath, + request: &AppIDRequest{ + Name: testApp.Name, + RedirectURL: testApp.RedirectURI, + Duration: testApp.SessionDuration.String(), + Secret: testAppSecret, + }, + response: &AppIDResponse{ + ID: testApp.ID(secret).String(), + }, + }.Run(t) + testCaseAPIHandler[AppIDRequest, AppIDResponse]{ + name: "no request", + method: http.MethodPost, + endpoint: AppsPath, + request: nil, + err: DecodeAppIDRequestErr, + }.Run(t) + testCaseAPIHandler[AppIDRequest, AppIDResponse]{ + name: "invalid request", + method: http.MethodPost, + endpoint: AppsPath, + request: &AppIDRequest{ + Name: testAppName, + RedirectURL: testAppRedirectURL, + Duration: time.Second.String(), + Secret: testAppSecret, + }, + err: InvalidAppIDErr, + }.Run(t) +} + +func TestRequestTokenAndStatusHandler(t *testing.T) { + testApp := &token.App{ + Name: testAppName, + RedirectURI: testAppRedirectURL, + SessionDuration: testAppSessionDuration, + } + secret := new(token.Secret).SetParts([]byte(testServerSecret), []byte(testAppSecret)) + testApp.SetSecret(secret) + testAppID := testApp.ID(secret) + testCaseAPIHandler[TokenRequest, any]{ + name: "no appID request", + method: http.MethodPost, + endpoint: TokensPath, + header: http.Header{ + appSecretHeader: []string{testAppSecret}, + }, + request: &TokenRequest{ + Email: testUserEmail, + }, + response: nil, + err: InvalidAppHeadersErr, + }.Run(t) + testCaseAPIHandler[TokenRequest, any]{ + name: "invalid app id request", + method: http.MethodPost, + endpoint: TokensPath, + header: http.Header{ + appIDHeader: []string{"invalid"}, + appSecretHeader: []string{testAppSecret}, + }, + request: &TokenRequest{ + Email: testUserEmail, + }, + response: nil, + err: InvalidAppIDErr, + }.Run(t) + testCaseAPIHandler[TokenRequest, any]{ + name: "no app secret request", + method: http.MethodPost, + endpoint: TokensPath, + header: http.Header{ + appIDHeader: []string{testAppID.String()}, + }, + request: &TokenRequest{ + Email: testUserEmail, + }, + response: nil, + err: InvalidAppHeadersErr, + }.Run(t) + testCaseAPIHandler[TokenRequest, any]{ + name: "no email provided", + method: http.MethodPost, + endpoint: TokensPath, + header: http.Header{ + appIDHeader: []string{testAppID.String()}, + appSecretHeader: []string{testAppSecret}, + }, + request: &TokenRequest{ + Email: "", + }, + response: nil, + err: GenerateTokenErr, + }.Run(t) + invalid := []byte("invalid") + testCaseAPIHandler[[]byte, any]{ + name: "no request", + method: http.MethodPost, + endpoint: TokensPath, + header: http.Header{ + appIDHeader: []string{testAppID.String()}, + appSecretHeader: []string{testAppSecret}, + }, + request: &invalid, + response: nil, + err: DecodeTokenRequestErr, + }.Run(t) + + login.Template = email.EmailTemplate{ + HTML: "", + Plain: `\[{{.Token}}]`, + } + testCaseAPIHandler[TokenRequest, any]{ + name: "valid request", + method: http.MethodPost, + endpoint: TokensPath, + header: http.Header{ + appIDHeader: []string{testAppID.String()}, + appSecretHeader: []string{testAppSecret}, + }, + request: &TokenRequest{ + Email: testUserEmail, + }, + response: nil, + }.Run(t) + + var testToken *token.Token + select { + case receivedMsg := <-inboxChan: + testToken = login.FindToken(testUserEmail, receivedMsg) + if testToken == nil { + t.Fatal("could not find token in email") + } + break + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for the email to be received") + } + + testCaseAPIHandler[TokenStatusRequest, TokenStatusResponse]{ + name: "valid token status request", + method: http.MethodPut, + endpoint: TokensPath, + header: http.Header{ + appIDHeader: []string{testAppID.String()}, + appSecretHeader: []string{testAppSecret}, + }, + request: &TokenStatusRequest{ + Token: testToken.String(), + Email: testUserEmail, + }, + response: &TokenStatusResponse{ + Valid: true, + Expiration: testToken.Expiration().Time(), + }, + }.Run(t) + + testCaseAPIHandler[TokenStatusRequest, any]{ + name: "invalid app id", + method: http.MethodPut, + endpoint: TokensPath, + header: http.Header{ + appIDHeader: []string{"invalid"}, + appSecretHeader: []string{testAppSecret}, + }, + request: &TokenStatusRequest{ + Token: testToken.String(), + Email: testUserEmail, + }, + response: nil, + err: InvalidAppIDErr, + }.Run(t) + + testCaseAPIHandler[TokenStatusRequest, any]{ + name: "no app secret", + method: http.MethodPut, + endpoint: TokensPath, + header: http.Header{ + appIDHeader: []string{testAppID.String()}, + }, + request: &TokenStatusRequest{ + Token: testToken.String(), + Email: testUserEmail, + }, + response: nil, + err: InvalidAppHeadersErr, + }.Run(t) + + testCaseAPIHandler[TokenStatusRequest, any]{ + name: "no headers", + method: http.MethodPut, + endpoint: TokensPath, + request: &TokenStatusRequest{ + Token: testToken.String(), + Email: testUserEmail, + }, + response: nil, + err: InvalidAppHeadersErr, + }.Run(t) + + testCaseAPIHandler[any, any]{ + name: "no request", + method: http.MethodPut, + endpoint: TokensPath, + header: http.Header{ + appIDHeader: []string{testAppID.String()}, + appSecretHeader: []string{testAppSecret}, + }, + err: DecodeTokenStatusRequestErr, + }.Run(t) +} diff --git a/api/helpers.go b/api/helpers.go new file mode 100644 index 0000000..3511b7d --- /dev/null +++ b/api/helpers.go @@ -0,0 +1,25 @@ +package api + +import ( + "fmt" + "net/http" +) + +// appConfigFromRequest extracts the app id and app secret from the request +// headers. It returns an error if the app id or app secret is missing. The +// app id and app secret are used to authenticate the app making the request. +// The app id is a unique identifier for the app, and the app secret is a +// shared secret used to verify the authenticity of the request for this +// service. +func appConfigFromRequest(r *http.Request) (string, string, error) { + // get the app id from the request header + strAppID := r.Header.Get(appIDHeader) + if strAppID == "" { + return "", "", fmt.Errorf("missing app id") + } + strAppSecret := r.Header.Get(appSecretHeader) + if strAppSecret == "" { + return "", "", fmt.Errorf("missing app secret") + } + return strAppID, strAppSecret, nil +} diff --git a/api/helpers_test.go b/api/helpers_test.go new file mode 100644 index 0000000..8d0c9e8 --- /dev/null +++ b/api/helpers_test.go @@ -0,0 +1,60 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestAppConfigFromRequest(t *testing.T) { + tests := []struct { + name string + headers map[string]string + expectedAppID string + expectedSecret string + expectError bool + }{ + { + name: "Valid headers", + headers: map[string]string{appIDHeader: "testAppID", appSecretHeader: "testAppSecret"}, + expectedAppID: "testAppID", + expectedSecret: "testAppSecret", + expectError: false, + }, + { + name: "Missing app id", + headers: map[string]string{appSecretHeader: "testAppSecret"}, + expectError: true, + }, + { + name: "Missing app secret", + headers: map[string]string{appIDHeader: "testAppID"}, + expectError: true, + }, + { + name: "Missing both headers", + headers: map[string]string{}, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + for key, value := range tt.headers { + req.Header.Set(key, value) + } + + appID, appSecret, err := appConfigFromRequest(req) + if (err != nil) != tt.expectError { + t.Errorf("expected error: %v, got: %v", tt.expectError, err) + } + if appID != tt.expectedAppID { + t.Errorf("expected appID: %s, got: %s", tt.expectedAppID, appID) + } + if appSecret != tt.expectedSecret { + t.Errorf("expected appSecret: %s, got: %s", tt.expectedSecret, appSecret) + } + }) + } +} diff --git a/api/io/error.go b/api/io/error.go new file mode 100644 index 0000000..d5b77c6 --- /dev/null +++ b/api/io/error.go @@ -0,0 +1,88 @@ +package io + +import ( + "encoding/json" + "fmt" + "net/http" +) + +// APIError represents an error response from the API. It includes a code, +// message, and optional error string. The StatusCode field is used to set the +// HTTP status code for the response. It has some methods to manipulate the +// error message and write the error response to an http.ResponseWriter. +type APIError struct { + Code int `json:"code"` + Message string `json:"message"` + Err string `json:"error,omitempty"` + StatusCode int `json:"-"` +} + +// Error implements the error interface for APIError. It returns a string +// representation of the error, including the code, message, error string, +// and status code. It can be used for logging or debugging purposes. +func (e *APIError) Error() string { + return fmt.Sprintf("code: %d, message: %s, error: %s, status_code: %d", e.Code, e.Message, e.Err, e.StatusCode) +} + +// WithErr appends an error message to the existing error string in the +// APIError. If the existing error string is empty, it sets it to the new +// error message. This method is useful for chaining error messages together +// for better debugging and logging. It returns the updated APIError instance +// and also updates the current instance. +func (e *APIError) WithErr(err error) *APIError { + if e.Err == "" { + e.Err = err.Error() + return e + } + e.Err = fmt.Sprintf("%s: %s", e.Err, err.Error()) + return e +} + +// With appends a string message to the existing message in the APIError. If +// the existing message is empty, it sets it to the new message. This method +// is useful for chaining messages together for better debugging and logging. +// It returns the updated APIError instance and also updates the current +// instance. +func (e *APIError) With(msg string) *APIError { + if e.Message == "" { + e.Message = msg + return e + } + e.Message = fmt.Sprintf("%s: %s", e.Message, msg) + return e +} + +// WriteJSON writes the APIError as a JSON response to the provided +// http.ResponseWriter. It sets the Content-Type header to "application/json" +// and writes the status code and serialized error bytes to the response. +// If an error occurs during writing, it writes an internal server error +// response instead. +func (e *APIError) Write(w http.ResponseWriter) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(e.StatusCode) + if _, err := w.Write(e.bytes()); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// bytes serializes the APIError to JSON bytes. If an error occurs during +// serialization, it returns nil. +func (e *APIError) bytes() []byte { + bErr, err := json.Marshal(e) + if err != nil { + return nil + } + return bErr +} + +// NewAPIError creates a new APIError instance with the provided code and +// status code. It initializes the error string and message to empty strings. +// This function is useful for creating a new APIError instance with +// specific error codes and status codes. It returns a pointer to the +// newly created APIError instance. +func NewAPIError(code, status int) *APIError { + return &APIError{ + Code: code, + StatusCode: status, + } +} diff --git a/api/io/error_test.go b/api/io/error_test.go new file mode 100644 index 0000000..4da371f --- /dev/null +++ b/api/io/error_test.go @@ -0,0 +1,90 @@ +package io + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" +) + +func TestNewAPIError(t *testing.T) { + err := NewAPIError(1001, http.StatusBadRequest) + if err.Code != 1001 { + t.Errorf("expected code 1001, got %d", err.Code) + } + if err.StatusCode != http.StatusBadRequest { + t.Errorf("expected status code %d, got %d", http.StatusBadRequest, err.StatusCode) + } + if err.Message != "" { + t.Errorf("expected empty message, got %s", err.Message) + } + if err.Err != "" { + t.Errorf("expected empty error string, got %s", err.Err) + } +} + +func TestAPIError_Error(t *testing.T) { + err := NewAPIError(1001, http.StatusBadRequest) + err.Message = "Bad Request" + err.Err = "Invalid input" + expected := "code: 1001, message: Bad Request, error: Invalid input, status_code: 400" + if err.Error() != expected { + t.Errorf("expected %s, got %s", expected, err.Error()) + } +} + +func TestAPIError_WithErr(t *testing.T) { + err := NewAPIError(1001, http.StatusBadRequest) + _ = err.WithErr(errors.New("Invalid input")) + if err.Err != "Invalid input" { + t.Errorf("expected error string 'Invalid input', got %s", err.Err) + } + + _ = err.WithErr(errors.New("Missing field")) + if err.Err != "Invalid input: Missing field" { + t.Errorf("expected error string 'Invalid input: Missing field', got %s", err.Err) + } +} + +func TestAPIError_With(t *testing.T) { + err := NewAPIError(1001, http.StatusBadRequest) + _ = err.With("Bad Request") + if err.Message != "Bad Request" { + t.Errorf("expected message 'Bad Request', got %s", err.Message) + } + + _ = err.With("Invalid input") + if err.Message != "Bad Request: Invalid input" { + t.Errorf("expected message 'Bad Request: Invalid input', got %s", err.Message) + } +} + +func TestAPIError_Write(t *testing.T) { + err := NewAPIError(1001, http.StatusBadRequest) + err.Message = "Bad Request" + err.Err = "Invalid input" + + rr := httptest.NewRecorder() + err.Write(rr) + + if status := rr.Code; status != http.StatusBadRequest { + t.Errorf("expected status code %d, got %d", http.StatusBadRequest, status) + } + + expected := `{"code":1001,"message":"Bad Request","error":"Invalid input"}` + if rr.Body.String() != expected { + t.Errorf("expected body %s, got %s", expected, rr.Body.String()) + } +} + +func TestAPIError_bytes(t *testing.T) { + err := NewAPIError(1001, http.StatusBadRequest) + err.Message = "Bad Request" + err.Err = "Invalid input" + + data := err.bytes() + expected := `{"code":1001,"message":"Bad Request","error":"Invalid input"}` + if string(data) != expected { + t.Errorf("expected %s, got %s", expected, string(data)) + } +} diff --git a/api/io/req.go b/api/io/req.go new file mode 100644 index 0000000..0b9691d --- /dev/null +++ b/api/io/req.go @@ -0,0 +1,39 @@ +package io + +import ( + "encoding/json" + "fmt" + "io" + "net/http" +) + +// Request represents a request with a generic data type to be unmarshalled +// from the request body. It implements the Read method to read and +// unmarshal the request body into the Data field. +type Request[T any] struct { + Data T +} + +// Read reads the request body and unmarshals it into the Data field of the +// Request struct. It returns an error if the request body is nil or empty, +// or if there is an error during unmarshalling. If the Request struct is nil, +// it initializes a new instance of Request[T]. This method is useful for +// handling incoming requests in a generic way, allowing for different data +// types to be processed without needing to define separate request structs +// for each type. +func (req *Request[T]) Read(r *http.Request) error { + if req == nil { + req = new(Request[T]) + } + if r.Body == nil { + return fmt.Errorf("nil request body") + } + rawBody, err := io.ReadAll(r.Body) + if err != nil { + return fmt.Errorf("failed to read request body: %w", err) + } + if len(rawBody) == 0 { + return fmt.Errorf("empty request body") + } + return json.Unmarshal(rawBody, &req.Data) +} diff --git a/api/io/req_test.go b/api/io/req_test.go new file mode 100644 index 0000000..98b9790 --- /dev/null +++ b/api/io/req_test.go @@ -0,0 +1,57 @@ +package io + +import ( + "bytes" + "encoding/json" + "net/http" + "testing" +) + +func TestRead(t *testing.T) { + type Data struct { + Message string `json:"message"` + } + data := &Data{Message: "Hello, World!"} + body, _ := json.Marshal(data) + req, err := http.NewRequest("POST", "/", bytes.NewBuffer(body)) + if err != nil { + t.Fatal(err) + } + + var request Request[Data] + if err := request.Read(req); err != nil { + t.Errorf("unexpected error: %v", err) + } + if request.Data.Message != data.Message { + t.Errorf("expected %s, got %s", data.Message, request.Data.Message) + } +} + +func TestRead_EmptyBody(t *testing.T) { + noBody, err := http.NewRequest("POST", "/", nil) + if err != nil { + t.Fatal(err) + } + + if nilReq := new(Request[any]).Read(noBody); nilReq == nil { + t.Errorf("expected error, got nil") + } + + req, err := http.NewRequest("POST", "/", bytes.NewBuffer([]byte(""))) + if err != nil { + t.Fatal(err) + } + + var request *Request[any] + err = request.Read(req) + if err == nil { + t.Errorf("expected error, got nil") + } + err = new(Request[any]).Read(req) + if err == nil { + t.Errorf("expected error, got nil") + } + if err.Error() != "empty request body" { + t.Errorf("expected empty request body error, got %v", err) + } +} diff --git a/api/io/res.go b/api/io/res.go new file mode 100644 index 0000000..a73cc8f --- /dev/null +++ b/api/io/res.go @@ -0,0 +1,113 @@ +package io + +import ( + "encoding/json" + "net/http" +) + +// Response represents a response with a generic data type. It can be used to +// send JSON responses or plain text responses. The Data field holds the +// response data, and the empty field indicates whether the response is empty +// or not. It has methods to write the response to an http.ResponseWriter in +// JSON format or plain text format. +type Response[T any] struct { + Data T + empty bool +} + +// ResponseWith creates a new Response instance with the provided data. If +// the data is nil, it returns an empty response. This method is useful for +// creating responses with different data types without needing to define +// separate response structs for each type. It returns a pointer to the +// Response instance. +func ResponseWith[T any](data *T) *Response[T] { + if data == nil { + return &Response[T]{empty: true} + } + return &Response[T]{ + Data: *data, + empty: false, + } +} + +// OkResponse creates a new Response instance with the provided byte slice. +// If the byte slice is empty, it returns an empty response. This method is +// useful for creating responses with raw byte data. It returns a pointer to +// the Response instance. The empty field indicates whether the response is +// empty or not. If the byte slice is empty, the response is considered empty. +// If the byte slice is not empty, the response is considered non-empty. +func OkResponse(body ...byte) *Response[any] { + if len(body) > 0 { + return &Response[any]{Data: body, empty: false} + } + return &Response[any]{empty: true} +} + +// WriteJSON writes the response data to the provided http.ResponseWriter in +// JSON format. It sets the Content-Type header to "application/json" and +// writes the response data as JSON. If the response is empty, it writes a +// plain text "OK" response with a 200 OK status code. If there is an error +// during JSON encoding or response writing, it writes an error response with +// a 500 Internal Server Error status code. This method is useful for sending +// JSON responses to the client. It can be used in HTTP handlers or middleware +// to send structured JSON responses. +func (r *Response[T]) WriteJSON(w http.ResponseWriter) { + if !r.empty { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(r.Data); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + return + } + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte(http.StatusText(http.StatusOK))); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// Write writes the response data to the provided http.ResponseWriter in +// plain text format. It sets the Content-Type header to "text/plain" and +// writes the response data as plain text. If the response is empty, it writes +// a plain text "OK" response with a 200 OK status code. If there is an error +// during response writing, it writes an error response with a 500 Internal +// Server Error status code. This method is useful for sending plain text +// responses to the client. It can be used in HTTP handlers or middleware to +// send simple text responses. It is a more generic method than WriteJSON, as +// it does not require the response data to be JSON-serializable. +func (r *Response[T]) Write(w http.ResponseWriter) { + w.WriteHeader(http.StatusOK) + if r.empty { + if _, err := w.Write([]byte(http.StatusText(http.StatusOK))); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + return + } + if data, ok := r.bytes(); ok { + if _, err := w.Write(data); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + } +} + +// bytes returns the response data as a byte slice. If the response is empty, +// it returns nil and a boolean indicating that the response is empty. If the +// response data is not a byte slice, it returns nil and a boolean indicating +// that the response data is not a byte slice. This method is useful for +// converting the response data to a byte slice for writing to the response +// writer or for further processing. +func (r *Response[T]) bytes() ([]byte, bool) { + // check if the response is empty + if r.empty { + return nil, true + } + // ensure that the response data is an slice of bytes + switch v := any(r.Data).(type) { + case []byte: + return v, true + case string: + return []byte(v), true + default: + return nil, false + } +} diff --git a/api/io/res_test.go b/api/io/res_test.go new file mode 100644 index 0000000..fc6925d --- /dev/null +++ b/api/io/res_test.go @@ -0,0 +1,179 @@ +package io + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestResponseWith(t *testing.T) { + type Data struct { + Message string `json:"message"` + } + nilResp := ResponseWith[Data](nil) + if !nilResp.empty { + t.Errorf("expected response to be empty") + } + data := &Data{Message: "Hello, World!"} + resp := ResponseWith(data) + if resp.empty { + t.Errorf("expected response to be non-empty") + } + if resp.Data.Message != data.Message { + t.Errorf("expected %s, got %s", data.Message, resp.Data.Message) + } +} + +func TestOkResponse(t *testing.T) { + resp := OkResponse() + if !resp.empty { + t.Errorf("expected response to be empty") + } +} + +func TestWriteJSON(t *testing.T) { + type Data struct { + Message string `json:"message"` + } + data := &Data{Message: "Hello, World!"} + resp := ResponseWith(data) + + rr := httptest.NewRecorder() + resp.WriteJSON(rr) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("expected status code %d, got %d", http.StatusOK, status) + } + + expected, _ := json.Marshal(data) + if rr.Body.String() != string(expected)+"\n" { // Ensure newline is accounted for + t.Errorf("expected body %s, got %s", string(expected)+"\n", rr.Body.String()) + } + // write json with nil data + nilResp := ResponseWith[string](nil) + rr = httptest.NewRecorder() + nilResp.WriteJSON(rr) + if status := rr.Code; status != http.StatusOK { + t.Errorf("expected status code %d, got %d", http.StatusOK, status) + } +} + +func TestWriteJSON_Empty(t *testing.T) { + resp := OkResponse() + + rr := httptest.NewRecorder() + resp.WriteJSON(rr) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("expected status code %d, got %d", http.StatusOK, status) + } + + if rr.Body.String() != "OK" { + t.Errorf("expected body OK, got %s", rr.Body.String()) + } +} + +func TestOkResponse_WithBody(t *testing.T) { + body := []byte("Hello, World!") + resp := OkResponse(body...) + + if resp.empty { + t.Errorf("expected response to be non-empty") + } + + if string(resp.Data.([]byte)) != string(body) { + t.Errorf("expected body %s, got %s", string(body), string(resp.Data.([]byte))) + } +} + +func TestWrite(t *testing.T) { + data := []byte("Hello, World!") + resp := ResponseWith(&data) + + rr := httptest.NewRecorder() + resp.Write(rr) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("expected status code %d, got %d", http.StatusOK, status) + } + + // Adjust to match the expected plain text output + if rr.Body.String() != string(data) { + t.Errorf("expected body %s, got %s", string(data), rr.Body.String()) + } +} + +func TestWrite_Empty(t *testing.T) { + resp := OkResponse() + + rr := httptest.NewRecorder() + resp.Write(rr) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("expected status code %d, got %d", http.StatusOK, status) + } + + if rr.Body.String() != "OK" { + t.Errorf("expected body OK, got %s", rr.Body.String()) + } +} + +func TestBytes(t *testing.T) { + body := []byte("Hello, World!") + resp := OkResponse(body...) + + data, ok := resp.bytes() + if !ok { + t.Errorf("expected bytes to be valid") + } + + if string(data) != string(body) { + t.Errorf("expected body %s, got %s", string(body), string(data)) + } +} + +func TestBytes_Empty(t *testing.T) { + resp := OkResponse() + + data, ok := resp.bytes() + if !ok { + t.Errorf("expected bytes to be valid") + } + if data != nil { + t.Errorf("expected nil data for empty response, got %v", data) + } + + // custom text response + bmsg := []byte("Hello, World!") + bresp := ResponseWith(&bmsg) + bdata, ok := bresp.bytes() + if !ok { + t.Errorf("expected bytes to be valid") + } + if string(bdata) != string(bmsg) { + t.Errorf("expected body %s, got %s", string(bmsg), string(bdata)) + } + + // string type + msg := "Hello, World!" + sresp := ResponseWith(&msg) + sdata, ok := sresp.bytes() + if !ok { + t.Errorf("expected bytes to be valid") + } + if string(sdata) != msg { + t.Errorf("expected body %s, got %s", msg, string(sdata)) + } + + // invalid type + imsg := 123 + iresp := ResponseWith(&imsg) + idata, ok := iresp.bytes() + if ok { + t.Errorf("expected bytes to be invalid") + } + if idata != nil { + t.Errorf("expected nil data for invalid response, got %v", idata) + } +} diff --git a/api/routes.go b/api/routes.go new file mode 100644 index 0000000..78631ae --- /dev/null +++ b/api/routes.go @@ -0,0 +1,26 @@ +package api + +// routes paths constants +const ( + // HealthCheckPath constant is the path used to check the health of the API + // server. It is a string with a value of "/health". + HealthCheckPath = "/ping" + // AppsPath constant is the path used to create the apps in the API server. + AppsPath = "/apps" + // TokensPath constant is the path used to generate and verify the tokens + // in the API server. + TokensPath = "/tokens" + // DemoInboxPath constant is the path used to get the demo email inbox + // in the API server when it runs in demo mode. + DemoInboxPath = "/demo/inbox" +) + +// other api related constants +const ( + // appIDHeader constant is the header of the app ID in the request. It is + // used to authenticate the app making the request. + appIDHeader = "AppID" + // appSecretHeader constant is the header of the app secret in the request + // It is used to authenticate the app making the request. + appSecretHeader = "AppSecret" +) diff --git a/api/service.go b/api/service.go index ebe696d..acd4b30 100644 --- a/api/service.go +++ b/api/service.go @@ -3,7 +3,6 @@ package api import ( "context" "fmt" - "log" "net/http" "os" "os/signal" @@ -12,77 +11,87 @@ import ( "time" "github.com/lucasmenendez/apihandler" - "github.com/simpleauthlink/authapi/db" - "github.com/simpleauthlink/authapi/email" - "github.com/simpleauthlink/authapi/helpers" + "github.com/simpleauthlink/authapi/internal/fakesmtpserver" + "github.com/simpleauthlink/authapi/notification" ) -// Config struct represents the configuration needed to init the service. It -// includes the email configuration, the server hostname, the server port, the -// data path to store the database, and the cleaner cooldown to clean the -// expired tokens. +// Config struct represents the configuration for the API service. It contains +// the server address, server port, and secret key for the service. The server +// address is the address where the service will listen for incoming requests, +// and the server port is the port number where the service will listen for +// incoming requests. The secret key is used to sign and verify tokens. The +// demo mode is used to enable or disable the demo functionality of the service. +// The demo SMTP address and port are used to configure the demo mail server. +// The demo mode is used to enable or disable the demo functionality of the +// service. type Config struct { - email.EmailConfig - Server string - ServerPort int - CleanerCooldown time.Duration + Server string + ServerPort int + Secret string + // demo stuff + DemoMode bool + DemoSMTPAddr string + DemoSMTPPort int } -// Service struct represents the service that is going to be started. It -// includes the context and the cancel function to stop the service, the wait -// group to wait for the background processes to finish, the configuration, -// the database connection and the api handler. +// Service struct represents the API service. It contains the context, cancel +// function, wait group, configuration, notification queue, API handler, HTTP +// server, and demo mail server. The context is used to manage the lifecycle +// of the service, while the wait group is used to wait for background processes +// to finish. The notification queue is used to send notifications, and the API +// handler is used to handle incoming requests. The HTTP server is used to +// serve the API endpoints, and the demo mail server is used to simulate +// sending emails in demo mode. +// The demo mail server is a fake SMTP server that captures emails sent to it +// for testing purposes. The demo mail inbox is a channel that receives the +// captured emails. type Service struct { ctx context.Context cancel context.CancelFunc wait sync.WaitGroup cfg *Config - db db.DB - emailQueue *email.EmailQueue + nq notification.Queue handler *apihandler.Handler httpServer *http.Server + // demo stuff + demoMailServer *fakesmtpserver.FakeSMTPServer + demoMailInbox chan string } -// New function creates a new service based on the provided context, the db -// interface and configuration. It initializes the email queue, creates the -// service and sets the api handlers. If something goes wrong during the -// process, it returns an error. -func New(ctx context.Context, db db.DB, cfg *Config) (*Service, error) { +// New function creates a new service instance. It takes a context, a config +// struct, and a notification queue as parameters. It returns a pointer to the +// service instance and an error if something goes wrong during the process. +// The function is responsible for setting up the service and its dependencies. +// It handles the configuration, rate limiting, and HTTP server setup. It also +// manages the demo mode functionality, including the demo mail server and +// inbox. The function is designed to be used as a constructor for the service +// and is responsible for initializing all the necessary components for the +// service to function properly. +func New(ctx context.Context, cfg *Config, nq notification.Queue) (*Service, error) { internalCtx, cancel := context.WithCancel(ctx) - emailQueue, err := email.NewEmailQueue(internalCtx, &cfg.EmailConfig) - if err != nil { - if emailQueue == nil { - cancel() - return nil, err - } - log.Println("WRN: something occurs during email queue creation:", err) - } // create the service srv := &Service{ - ctx: internalCtx, - cancel: cancel, - cfg: cfg, - db: db, - emailQueue: emailQueue, - handler: apihandler.NewHandler(&apihandler.Config{ - CORS: true, - RateLimitConfig: &apihandler.RateLimitConfig{ - Rate: 2, - Limit: 10, - }, - }), + ctx: internalCtx, + cancel: cancel, + cfg: cfg, + nq: nq, + handler: apihandler.NewHandler(true, nil), } - srv.handler.Get(helpers.HealthCheckPath, func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }) - // user handlers - srv.handler.Post(helpers.UserEndpointPath, srv.userTokenHandler) - srv.handler.Get(helpers.UserEndpointPath, srv.validateUserTokenHandler) - // app handlers - srv.handler.Get(helpers.AppEndpointPath, srv.appHandler) - srv.handler.Post(helpers.AppEndpointPath, srv.appTokenHandler) - srv.handler.Put(helpers.AppEndpointPath, srv.updateAppHandler) - srv.handler.Delete(helpers.AppEndpointPath, srv.delAppHandler) + // demo stuff + if cfg.DemoMode { + srv.demoMailInbox = make(chan string, 1) + srv.demoMailServer = fakesmtpserver.NewServer(cfg.DemoSMTPAddr, + cfg.DemoSMTPPort, srv.demoMailInbox) + if err := srv.demoMailServer.Start(internalCtx); err != nil { + return nil, err + } + _ = srv.handler.Get(DemoInboxPath, srv.demoInboxHandler) + } + // register the routes and handlers + _ = srv.handler.Post(AppsPath, srv.generateAppIDHandler) + _ = srv.handler.Post(TokensPath, srv.requestTokenHandler) + _ = srv.handler.Put(TokensPath, srv.verifyTokenHandler) + _ = srv.handler.Get(HealthCheckPath, srv.healthCheckHandler) // build the http server srv.httpServer = &http.Server{ Addr: fmt.Sprintf("%s:%d", cfg.Server, cfg.ServerPort), @@ -91,13 +100,8 @@ func New(ctx context.Context, db db.DB, cfg *Config) (*Service, error) { return srv, nil } -// Start method starts the service. It starts the token cleaner and the api -// server. If something goes wrong during the process, it returns an error. +// Start method starts the service by starting the http server. func (s *Service) Start() error { - // start the email queue - s.emailQueue.Start() - // start the token cleaner in the background - s.sanityTokenCleaner() // start the api server if err := s.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { return err @@ -106,19 +110,29 @@ func (s *Service) Start() error { } // Stop method stops the service. It cancels the context and waits for the -// background processes to finish. It closes the database. If something goes -// wrong during the process, it returns an error. -func (s *Service) Stop() error { - // close the database - if err := s.db.Close(); err != nil { - return fmt.Errorf("error closing db: %w", err) - } - // stop the email queue - s.emailQueue.Stop() +// background processes to finish. It also closes the http server and the +// demo mail server if it is running. +func (s *Service) Stop() { // cancel the context and wait for the background processes finish s.cancel() defer s.wait.Wait() - return nil +} + +// Ping method checks if the service is up and running. It sends a GET request +// to the health check endpoint and returns true if the response status code +// is 200 OK, otherwise it returns false. If something goes wrong during the +// process, it returns false. +func (s *Service) Ping() bool { + url := fmt.Sprintf("http://%s:%d%s", s.cfg.Server, s.cfg.ServerPort, HealthCheckPath) + request, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return false + } + response, err := http.DefaultClient.Do(request) + if err != nil { + return false + } + return response.StatusCode == http.StatusOK } // WaitToShutdown method waits for the service to shutdown. It listens for the @@ -128,12 +142,8 @@ func (s *Service) WaitToShutdown() error { done := make(chan os.Signal, 1) signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) <-done - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - defer func() { - if err := s.Stop(); err != nil { - log.Println(err) - } - }() + defer s.Stop() return s.httpServer.Shutdown(ctx) } diff --git a/api/service_test.go b/api/service_test.go index c36056e..4703a1e 100644 --- a/api/service_test.go +++ b/api/service_test.go @@ -2,36 +2,84 @@ package api import ( "context" + "fmt" + "os" "testing" "time" - "github.com/simpleauthlink/authapi/db" - "github.com/simpleauthlink/authapi/email" + "github.com/simpleauthlink/authapi/internal/fakesmtpserver" + "github.com/simpleauthlink/authapi/notification/email" ) -func TestNew(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() +const ( + testServerAddr = "127.0.0.1" + testServerSMTPPort = 2526 + testServerAPIPort = 5555 + testServerSecret = "server-secret" + testSenderName = "TestAPI" + testSender = "api-test@testmail.com" + testUserEmail = "user@testmail.com" + testAppName = "TestApp" + testAppRedirectURL = "http://testapp.com" + testAppSessionDuration = time.Second * 35 + testAppSecret = "test-secret" +) + +var ( + testServerApiURL = fmt.Sprintf("http://%s:%d", testServerAddr, testServerAPIPort) + inboxChan = make(chan string, 1) +) - testDB := new(db.TempDriver) - testDB.Init(nil) - srv, err := New(ctx, testDB, &Config{ - Server: "localhost", - ServerPort: 8080, - CleanerCooldown: 30 * time.Second, - EmailConfig: email.EmailConfig{ - EmailHost: "smtp.gmail.com", - EmailPort: 587, - Address: "", - Password: "", - }, +func TestMain(m *testing.M) { + defer close(inboxChan) + // create context with cancel + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // start test SMTP server to receive the email + testSrv := fakesmtpserver.NewServer(testServerAddr, testServerSMTPPort, inboxChan) + if err := testSrv.Start(ctx); err != nil { + panic(err) + } + defer testSrv.Stop() + // create email queue with valid config + eq, err := email.NewEmailQueue(ctx, &email.EmailConfig{ + SMTPServer: testServerAddr, + SMTPPort: testServerSMTPPort, + FromName: testSenderName, + FromAddress: testSender, }) if err != nil { - t.Errorf("expected nil, got %v", err) - return + panic(err) + } + eq.Start() + defer eq.Stop() + // create the API service + apiSrv, err := New(ctx, &Config{ + Server: testServerAddr, + ServerPort: testServerAPIPort, + Secret: testServerSecret, + }, eq) + if err != nil { + panic(err) } - if srv == nil { - t.Errorf("expected not nil, got nil") - return + go func() { + if err := apiSrv.Start(); err != nil { + panic(err) + } + }() + defer apiSrv.Stop() + // make ping to the server to check if it is running + nRetries := 5 + for { + if nRetries == 0 { + panic("API server is not running") + } + if ok := apiSrv.Ping(); ok { + break + } + nRetries-- + time.Sleep(time.Second) } + // run the tests + os.Exit(m.Run()) } diff --git a/api/tokens.go b/api/tokens.go deleted file mode 100644 index 2f0f155..0000000 --- a/api/tokens.go +++ /dev/null @@ -1,179 +0,0 @@ -package api - -import ( - "fmt" - "log" - "net/url" - "strings" - "time" - - "github.com/simpleauthlink/authapi/db" - "github.com/simpleauthlink/authapi/helpers" -) - -// magicLink function generates and returns a magic link, the generated token -// and the associated app name, based on the provided app secret and the user -// email. If the secret or the email are empty, it returns an error. It gets -// the app id from the database based on the secret. It generates a token and -// calculates the expiration time based on the app session duration. It stores -// the token and the expiration time in the database. It returns the magic link -// composed of the app callback and the generated token. -func (s *Service) magicLink(rawSecret, email, redirectURL string, duration uint64) (string, string, string, error) { - // check if the secret and email are not empty - if len(rawSecret) == 0 || len(email) == 0 { - return "", "", "", fmt.Errorf("secret and email are required") - } - // get app secret from raw secret - appSecret, err := helpers.Hash(rawSecret, helpers.SecretSize) - if err != nil { - return "", "", "", err - } - // get app and app id from the database based on the secret - app, appId, err := s.db.AppBySecret(appSecret) - if err != nil { - return "", "", "", err - } - // get the number of tokens for the app using the app id as the prefix - numberOfAppTokens, err := s.db.CountTokens(appId) - if err != nil { - return "", "", "", err - } - // check if the number of tokens is greater than the users quota - if numberOfAppTokens >= app.UsersQuota { - return "", "", "", fmt.Errorf("users quota reached") - } - // generate token and calculate expiration - token, userId, err := helpers.EncodeUserToken(appId, email) - if err != nil { - return "", "", "", err - } - // by default, the session duration is the app session duration but it can - // be overwritten by the request - sessionDuration := app.SessionDuration - if duration > 0 { - sessionDuration = duration - } - expiration := time.Now().Add(time.Duration(sessionDuration) * time.Second) - // check if there is a token for the user and app in the database and delete - // it if it exists - tokenPrefix := strings.Join([]string{appId, userId}, helpers.TokenSeparator) - if err := s.db.DeleteTokensByPrefix(tokenPrefix); err != nil { - if err != db.ErrTokenNotFound { - log.Println("ERR: error checking token:", err) - } - } - // set token and expiration in the database - if err := s.db.SetToken(db.Token(token), expiration); err != nil { - return "", "", "", err - } - // return the magic link based on the app callback and the generated token - // by default, the redirect URL is the app redirect URL but it can be - // overwritten by the request - baseRawURL := app.RedirectURL - if redirectURL != "" { - baseRawURL = redirectURL - } - baseURL, err := url.Parse(baseRawURL) - if err != nil { - return "", "", "", fmt.Errorf("invalid redirect URL: %w", err) - } - urlQuery := baseURL.Query() - urlQuery.Set(helpers.TokenQueryParam, token) - baseURL.RawQuery = urlQuery.Encode() - return helpers.SafeURL(baseURL), token, app.Name, nil -} - -// validUserToken function checks if the provided token is valid. It checks if -// the token is not empty, if the app id is in the database, if the token is not -// expired and if the token is in the database. If the token is invalid, it -// returns false. If something goes wrong during the process, it logs the error -// and returns false. If the token is valid, it returns true. -func (s *Service) validUserToken(token, rawSecret string) bool { - // check if the token and secret are not empty - if len(token) == 0 || len(rawSecret) == 0 { - return false - } - // get the app id from the token - appId, _, err := helpers.DecodeUserToken(token) - if err != nil { - return false - } - // check if the secret is valid - if !s.validSecret(appId, rawSecret) { - return false - } - // get the token expiration from the database - expiration, err := s.db.TokenExpiration(db.Token(token)) - if err != nil { - return false - } - // check if the token is expired - if time.Now().After(expiration) { - if err := s.db.DeleteToken(db.Token(token)); err != nil { - log.Println("ERR: error deleting token:", err) - } - return false - } - return true -} - -// validAdminToken function checks if the provided token is a valid admin token. -// It checks if the token is not empty, if the app id is in the database, if the -// token is not expired and if the token is in the database. If the token is -// invalid, it returns false. It also returns the app id if the token is valid. -func (s *Service) validAdminToken(token, rawSecret string) (string, bool) { - // check if the token and secret are not empty - if len(token) == 0 || len(rawSecret) == 0 { - return "", false - } - // get the app id from the token - appId, userId, err := helpers.DecodeUserToken(token) - if err != nil { - return "", false - } - // the app id is composed by the admin user id hash and a nonce, so - // the app id starts with the admin user id, check if so - if !strings.HasPrefix(appId, userId) { - return "", false - } - // check if the secret is valid - if !s.validSecret(appId, rawSecret) { - return "", false - } - // get the token expiration from the database - expiration, err := s.db.TokenExpiration(db.Token(token)) - if err != nil { - return "", false - } - // check if the token is expired - if time.Now().After(expiration) { - if err := s.db.DeleteToken(db.Token(token)); err != nil { - log.Println("ERR: error deleting token:", err) - } - return "", false - } - return appId, true -} - -// sanityTokenCleaner function starts a goroutine that cleans the expired tokens -// from the database every time the cooldown time is reached. It uses a ticker -// to check the cooldown time and a context to stop the goroutine when the -// service is stopped. If something goes wrong during the process, it logs the -// error. -func (s *Service) sanityTokenCleaner() { - s.wait.Add(1) - go func() { - defer s.wait.Done() - ticker := time.NewTicker(s.cfg.CleanerCooldown) - for { - select { - case <-s.ctx.Done(): - return - case <-ticker.C: - if err := s.db.DeleteExpiredTokens(); err != nil { - log.Println("ERR: error deleting expired tokens:", err) - } - } - } - }() -} diff --git a/api/types.go b/api/types.go index 00d6aa9..4daaebd 100644 --- a/api/types.go +++ b/api/types.go @@ -1,27 +1,44 @@ package api -const ( - userTokenSubject = "Here is your magic link for '%s' πŸ”" - appTokenSubject = "Your app '%s' is ready! πŸŽ‰" +import ( + "time" + + "github.com/simpleauthlink/authapi/token" ) -// TokenRequest struct includes the required information by the API service to -// create a token, which is the email of the user. The app secret is also -// required but it is provided in the request headers. -type TokenRequest struct { - Email string `json:"email"` +type AppIDRequest struct { + Name string `json:"name"` + Duration string `json:"session_duration"` RedirectURL string `json:"redirect_url"` - Duration uint64 `json:"session_duration"` + Secret string `json:"secret"` +} + +func (data *AppIDRequest) parseApp() (*token.App, string) { + if duration, err := time.ParseDuration(data.Duration); err == nil { + app := &token.App{ + Name: data.Name, + RedirectURI: data.RedirectURL, + SessionDuration: duration, + } + return app, data.Secret + } + return new(token.App), "" +} + +type AppIDResponse struct { + ID string `json:"id"` +} + +type TokenRequest struct { + Email string `json:"email"` +} + +type TokenStatusRequest struct { + Token string `json:"token"` + Email string `json:"email"` } -// AppData struct includes the required information by the API service to -// create an app, which are the name, the email of the admin, the session -// duration and the callback URL. -type AppData struct { - Name string `json:"name"` - Email string `json:"admin_email"` - Duration uint64 `json:"session_duration"` - RedirectURL string `json:"redirect_url"` - UsersQuota int64 `json:"users_quota"` - CurrentUsers int64 `json:"current_users"` +type TokenStatusResponse struct { + Valid bool `json:"valid"` + Expiration time.Time `json:"expiration"` } diff --git a/assets/app_email_template.html b/assets/app_email_template.html deleted file mode 100644 index 6e77482..0000000 --- a/assets/app_email_template.html +++ /dev/null @@ -1,98 +0,0 @@ - - - - - - - Your app '{{.AppName}}' is ready πŸŽ‰ - - - - - - - - -
- - - - - - - - - - - - - - - - -
- -

SimpleAuth.link

-
- πŸ‘‹ Hi, {{.EmailHandler}}! -

- Your app '{{.AppName}}' has been successfully created βœ…. Here are the details of your app: -

- - - - - - - - - - - - - - - - - -
App ID{{.AppID}}
App Name{{.AppName}}
App Secret{{.Secret}}
Redirect URL{{.RedirectURL}}
-

- Check out the documentation to getting started integrating SimpleAuth with your app πŸš€. -
- - - - -
- πŸ€“ Read the documentation -
-
- ⚠️ Remember to keep your app secret safe and secure. ⚠️ -

- You can always regenerate a new app secret. -
-
- - - \ No newline at end of file diff --git a/client/client.go b/client/client.go deleted file mode 100644 index 71925b5..0000000 --- a/client/client.go +++ /dev/null @@ -1,136 +0,0 @@ -package client - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - - "github.com/simpleauthlink/authapi/api" - "github.com/simpleauthlink/authapi/helpers" -) - -// Client struct represents the client to interact with the API server. It -// contains the configuration of the client. The configuration includes the -// secret of the app and the API endpoint. The API endpoint is optional and if -// it is empty, it uses the default API endpoint. The client provides two -// methods to interact with the API server, RequestToken and ValidateToken. -type Client struct { - config *ClientConfig -} - -// New function creates a new client based on the provided configuration. It -// returns the client and an error if the configuration is invalid. The -// configuration must include, at least, the secret of your app. If the API -// endpoint is empty, it uses the default API endpoint. It validates the config -// and returns an error if the configuration is nil, the secret is empty or the -// API endpoint is invalid. -func New(config *ClientConfig) (*Client, error) { - if err := config.check(); err != nil { - return nil, err - } - return &Client{config: config}, nil -} - -// RequestToken function requests a token for the user based on the provided -// email. It returns an error if the email is empty. It receives the context -// and the token request. The token request includes the email of the user, the -// redirect URL and the session duration. The session duration is optional and -// if it is zero, it uses the default session duration. It creates a new URL -// based on the API endpoint, encodes the request, creates the request, sets -// the secret in the header, sets the content type and makes the request. It -// checks the status code and returns an error if the status code is different -// from 200, if so returns an error trying to decode the body of the response. -func (cli *Client) RequestToken(ctx context.Context, req *api.TokenRequest) error { - if req == nil || req.Email == "" { - return fmt.Errorf("email is required to request a token") - } - // create a new URL based on the API endpoint - url := new(url.URL) - *url = *cli.config.url - // set the path - url.Path = helpers.UserEndpointPath - // encode the request - encodedReq, err := json.Marshal(req) - if err != nil { - return fmt.Errorf("error encoding request: %w", err) - } - // create the request - buf := bytes.NewBuffer(encodedReq) - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url.String(), buf) - if err != nil { - return fmt.Errorf("error creating request: %w", err) - } - // set the secret in the header - httpReq.Header.Set(helpers.AppSecretHeader, cli.config.Secret) - // set the content type - httpReq.Header.Set("Content-Type", "application/json") - // make the request - res, err := http.DefaultClient.Do(httpReq) - if err != nil { - return fmt.Errorf("error making request: %w", err) - } - defer res.Body.Close() - // check the status code and return an error if the status code is different - // from 200, if so return an error trying to decode the body of the response - if res.StatusCode != http.StatusOK { - // decode body and return error - msg, err := io.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("unexpected status code: %d", res.StatusCode) - } - return fmt.Errorf("unexpected response: [%d] %s", res.StatusCode, string(msg)) - } - return nil -} - -// ValidateToken function validates the token provided using the API server. It -// returns true if the token is valid, false if the token is invalid, or an -// error if something goes wrong during the process. It receives the context, -// the token and the client configuration. The configuration must include, at -// least, the secret of your app. If the API endpoint is empty, it uses the -// default API endpoint. It validates the config and returns an error if the -// configuration is nil, the secret is empty or the API endpoint is invalid. -func (cli *Client) ValidateToken(ctx context.Context, token string) (bool, error) { - // create a new URL based on the API endpoint - url := new(url.URL) - *url = *cli.config.url - // add token to the query - query := url.Query() - query.Set(helpers.TokenQueryParam, token) - // set the path and query - url.Path = helpers.UserEndpointPath - url.RawQuery = query.Encode() - // create the request - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil) - if err != nil { - return false, fmt.Errorf("error creating request: %w", err) - } - // set the secret in the header - req.Header.Set(helpers.AppSecretHeader, cli.config.Secret) - // make the request - resp, err := http.DefaultClient.Do(req) - if err != nil { - return false, fmt.Errorf("error making request: %w", err) - } - defer resp.Body.Close() - // check the status code, return true if the status code is 200 or false if - // the status code is 401, otherwise return an error trying to decode the - // body of the response - switch resp.StatusCode { - case http.StatusOK: - return true, nil - case http.StatusUnauthorized: - return false, nil - default: - // decode body and return error - msg, err := io.ReadAll(resp.Body) - if err != nil { - return false, fmt.Errorf("unexpected status code: %d", resp.StatusCode) - } - return false, fmt.Errorf("unexpected response: [%d] %s", resp.StatusCode, string(msg)) - } -} diff --git a/client/config.go b/client/config.go deleted file mode 100644 index fc35fef..0000000 --- a/client/config.go +++ /dev/null @@ -1,40 +0,0 @@ -package client - -import ( - "fmt" - "net/url" - - "github.com/simpleauthlink/authapi/helpers" -) - -// ClientConfig struct represents the configuration needed to use the client. -type ClientConfig struct { - // APIEndpoint is the API hostname. - APIEndpoint string - url *url.URL - // Secret is the app secret on the API server. - Secret string -} - -// check function validates the configuration and returns an error if the -// configuration is invalid. It checks if the configuration is nil, if the -// secret is empty, and if the API endpoint is invalid. If the API endpoint is -// empty, it uses the default API endpoint. It returns an error if the -// configuration is nil, the secret is empty or the API endpoint is invalid. -func (conf *ClientConfig) check() error { - if conf == nil { - return fmt.Errorf("config is required") - } - if conf.APIEndpoint == "" { - conf.APIEndpoint = helpers.DefaultAPIEndpoint - } - if conf.Secret == "" { - return fmt.Errorf("secret is required") - } - var err error - conf.url, err = url.Parse(conf.APIEndpoint) - if err != nil { - return fmt.Errorf("invalid API endpoint: %w", err) - } - return nil -} diff --git a/cmd/authapi/main.go b/cmd/authapi/main.go index e1ca2c9..e487ef9 100644 --- a/cmd/authapi/main.go +++ b/cmd/authapi/main.go @@ -2,216 +2,83 @@ package main import ( "context" - "flag" "fmt" "log" - "os" - "strconv" - "time" "github.com/simpleauthlink/authapi/api" - "github.com/simpleauthlink/authapi/db/mongo" - "github.com/simpleauthlink/authapi/email" -) - -const ( - defaultHost = "0.0.0.0" - defaultPort = 8080 - defaultDatabaseURI = "mongodb://admin:password@localhost:27017/" - defaultDatabaseName = "simpleauth" - defaultEmailAddr = "" - defaultEmailPass = "" - defaultEmailHost = "" - defaultEmailPort = 587 - defaultTokenEmailTemplate = "assets/token_email_template.html" - defaultAppEmailTemplate = "assets/app_email_template.html" - defaultDisposableSrcURL = "https://raw.githubusercontent.com/disposable-email-domains/disposable-email-domains/master/disposable_email_blocklist.conf" - - hostFlag = "host" - portFlag = "port" - dbURIFlag = "db-uri" - dbNameFlag = "db-name" - emailAddrFlag = "email-addr" - emailPassFlag = "email-pass" - emailHostFlag = "email-host" - emailPortFlag = "email-port" - tokenEmailTemplateFlag = "email-token-template" - appEmailTemplateFlag = "email-app-template" - disposableSrcFlag = "disposable-src" - hostFlagDesc = "service host" - portFlagDesc = "service port" - dbURIFlagDesc = "database uri" - dbNameFlagDesc = "database name" - emailAddrFlagDesc = "email account address" - emailPassFlagDesc = "email account password" - emailHostFlagDesc = "email server host" - emailPortFlagDesc = "email server port" - tokenEmailTemplateDesc = "path to the html template of new token email" - appEmailTemplateDesc = "path to the html template of new app email" - disposableSrcDesc = "source url of list of disposable emails domains" - - hostEnv = "SIMPLEAUTH_HOST" - portEnv = "SIMPLEAUTH_PORT" - dbURIEnv = "SIMPLEAUTH_DB_URI" - dbNameEnv = "SIMPLEAUTH_DB_NAME" - emailAddrEnv = "SIMPLEAUTH_EMAIL_ADDR" - emailPassEnv = "SIMPLEAUTH_EMAIL_PASS" - emailHostEnv = "SIMPLEAUTH_EMAIL_HOST" - emailPortEnv = "SIMPLEAUTH_EMAIL_PORT" - tokenEmailTemplateEnv = "SIMPLEAUTH_TOKEN_EMAIL_TEMPLATE" - appEmailTemplateEnv = "SIMPLEAUTH_APP_EMAIL_TEMPLATE" - disposableSrcEnv = "SIMPLEAUTH_DISPOSABLE_SRC" + "github.com/simpleauthlink/authapi/cmd" + "github.com/simpleauthlink/authapi/internal/osflag" + "github.com/simpleauthlink/authapi/notification/email" ) type config struct { - host string - port int - dbURI string - dbName string - emailAddr string - emailPass string - emailHost string - emailPort int - tokenEmailTemplate string - appEmailTemplate string - disposableSrc string + host string + port int + emailAddr string + emailUser string + emailPass string + emailHost string + emailPort int + secret string +} + +func (c *config) String() string { + return fmt.Sprintf(`{"server": "%s:%d", "smtpServer": "%s:%d", "smtpAuth": "%s:%s", "secret": "%s"}`, + c.host, c.port, c.emailHost, c.emailPort, c.emailAddr, c.emailPass, c.secret) } func main() { log.SetFlags(log.LstdFlags | log.Lshortfile) - c, err := parseConfig() + c := new(config) + // get config from flags + osflag.StringVar(&c.host, cmd.HostEnv, cmd.HostFlag, cmd.DefaultHost, cmd.HostFlagDesc, false) + osflag.IntVar(&c.port, cmd.PortEnv, cmd.PortFlag, cmd.DefaultPort, cmd.HostFlagDesc, false) + osflag.StringVar(&c.emailAddr, cmd.EmailAddrEnv, cmd.EmailAddrFlag, cmd.DefaultEmailAddr, cmd.EmailAddrFlagDesc, true) + osflag.StringVar(&c.emailUser, cmd.EmailUserEnv, cmd.EmailUserFlag, cmd.DefaultEmailUser, cmd.EmailUserFlagDesc, true) + osflag.StringVar(&c.emailPass, cmd.EmailPassEnv, cmd.EmailPassFlag, cmd.DefaultEmailPass, cmd.EmailPassFlagDesc, true) + osflag.StringVar(&c.emailHost, cmd.EmailHostEnv, cmd.EmailHostFlag, cmd.DefaultEmailHost, cmd.EmailHostFlagDesc, true) + osflag.IntVar(&c.emailPort, cmd.EmailPortEnv, cmd.EmailPortFlag, cmd.DefaultEmailPort, cmd.EmailPortFlagDesc, false) + osflag.StringVar(&c.secret, cmd.SecretEnv, cmd.SecretFlag, cmd.DefaultSecret, cmd.SecretFlagDesc, true) + if err := osflag.Parse(nil); err != nil { + log.Fatalln("ERR: error parsing flags:", err) + } + if !osflag.Parsed() { + log.Fatalln("ERR: error parsing flags:", "flags not parsed") + osflag.PrintDefaults() + } + log.Println("INF: starting service with config:", c.String()) + // create email queue + emailQueue, err := email.NewEmailQueue(context.Background(), &email.EmailConfig{ + FromName: "SimpleAuthLink", + FromAddress: c.emailAddr, + SMTPUsername: c.emailUser, + SMTPPassword: c.emailPass, + SMTPServer: c.emailHost, + SMTPPort: c.emailPort, + }) if err != nil { - log.Fatalln("ERR: error parsing config:", err) - } - // init the database with mongo driver - db := new(mongo.MongoDriver) - if err := db.Init(mongo.Config{ - MongoURI: c.dbURI, - Database: c.dbName, - }); err != nil { - log.Fatalf("error initializing db: %v", err) + log.Fatalln("WRN: something occurs during email queue creation:", err) } + // start the email queue and defer to stop it + emailQueue.Start() + defer emailQueue.Stop() // create the service - service, err := api.New(context.Background(), db, &api.Config{ - EmailConfig: email.EmailConfig{ - Address: c.emailAddr, - Password: c.emailPass, - EmailHost: c.emailHost, - EmailPort: c.emailPort, - DisposableSrc: c.disposableSrc, - TokenEmailTemplate: c.tokenEmailTemplate, - AppEmailTemplate: c.appEmailTemplate, - }, - Server: c.host, - ServerPort: c.port, - CleanerCooldown: 30 * time.Minute, - }) + service, err := api.New(context.Background(), &api.Config{ + Server: c.host, + ServerPort: c.port, + Secret: c.secret, + }, emailQueue) if err != nil { log.Fatalln("ERR: error creating service:", err) } + // start the service in background go func() { if err := service.Start(); err != nil { log.Fatalln("ERR: error running service:", err) } }() // wait for the service to finish - service.WaitToShutdown() -} - -func parseConfig() (*config, error) { - var fhost, fdbURI, fdbName, femailAddr, femailPass, femailHost, ftokenEmailTemplate, fappEmailTemplate, fdisposableSrc string - var fport, femailPort int - // get config from flags - flag.StringVar(&fhost, hostFlag, defaultHost, hostFlagDesc) - flag.IntVar(&fport, portFlag, defaultPort, hostFlagDesc) - flag.StringVar(&fdbURI, dbURIFlag, defaultDatabaseURI, dbURIFlagDesc) - flag.StringVar(&fdbName, dbNameFlag, defaultDatabaseName, dbNameFlagDesc) - flag.StringVar(&femailAddr, emailAddrFlag, defaultEmailAddr, emailAddrFlagDesc) - flag.StringVar(&femailPass, emailPassFlag, defaultEmailPass, emailPassFlagDesc) - flag.StringVar(&femailHost, emailHostFlag, defaultEmailHost, emailHostFlagDesc) - flag.StringVar(&ftokenEmailTemplate, tokenEmailTemplateFlag, defaultTokenEmailTemplate, tokenEmailTemplateDesc) - flag.StringVar(&fappEmailTemplate, appEmailTemplateFlag, defaultAppEmailTemplate, appEmailTemplateDesc) - flag.IntVar(&femailPort, emailPortFlag, defaultEmailPort, emailPortFlagDesc) - flag.StringVar(&fdisposableSrc, disposableSrcFlag, defaultDisposableSrcURL, disposableSrcDesc) - flag.Parse() - // get config from env - envHost := os.Getenv(hostEnv) - envPort := os.Getenv(portEnv) - envDBURI := os.Getenv(dbURIEnv) - envDBName := os.Getenv(dbNameEnv) - envEmailAddr := os.Getenv(emailAddrEnv) - envEmailPass := os.Getenv(emailPassEnv) - envEmailHost := os.Getenv(emailHostEnv) - envEmailPort := os.Getenv(emailPortEnv) - envtokenEmailTemplate := os.Getenv(tokenEmailTemplateEnv) - envAppEmailTemplate := os.Getenv(appEmailTemplateEnv) - envDisposableSrc := os.Getenv(disposableSrcEnv) - - // check if the required flags are set - if femailAddr == "" && envEmailAddr == "" { - return nil, fmt.Errorf("email address is required, use -%s or set %s env var", emailAddrFlag, emailAddrEnv) - } - if femailPass == "" && envEmailPass == "" { - return nil, fmt.Errorf("email password is required, use -%s or set %s env var", emailPassFlag, emailPassEnv) - } - if femailHost == "" && envEmailHost == "" { - return nil, fmt.Errorf("email host is required, use -%s or set %s env var", emailHostFlag, emailHostEnv) - } - // set flags values by default - c := &config{ - host: fhost, - port: fport, - dbURI: fdbURI, - dbName: fdbName, - emailAddr: femailAddr, - emailPass: femailPass, - emailHost: femailHost, - emailPort: femailPort, - tokenEmailTemplate: ftokenEmailTemplate, - appEmailTemplate: fappEmailTemplate, - disposableSrc: fdisposableSrc, - } - // if some flags are not set, set them by env - if envHost != "" { - c.host = envHost - } - if envPort != "" { - if nenvPort, err := strconv.Atoi(envPort); err == nil { - c.port = nenvPort - } else { - return nil, fmt.Errorf("invalid port value: %s", envPort) - } - } - if envDBURI != "" { - c.dbURI = envDBURI - } - if envDBName != "" { - c.dbName = envDBName - } - if envEmailAddr != "" { - c.emailAddr = envEmailAddr - } - if envEmailPass != "" { - c.emailPass = envEmailPass - } - if envEmailHost != "" { - c.emailHost = envEmailHost - } - if envEmailPort != "" { - if nenvEmailPort, err := strconv.Atoi(envEmailPort); err == nil { - c.emailPort = nenvEmailPort - } else { - return nil, fmt.Errorf("invalid email port value: %s", envEmailPort) - } - } - if envtokenEmailTemplate != "" { - c.tokenEmailTemplate = envtokenEmailTemplate - } - if envAppEmailTemplate != "" { - c.appEmailTemplate = envAppEmailTemplate - } - if envDisposableSrc != "" { - c.disposableSrc = envDisposableSrc + if err := service.WaitToShutdown(); err != nil { + log.Fatalln("ERR: error waiting for service to finish:", err) } - return c, nil } diff --git a/cmd/consts.go b/cmd/consts.go new file mode 100644 index 0000000..0a5bb9e --- /dev/null +++ b/cmd/consts.go @@ -0,0 +1,38 @@ +package cmd + +const ( + DefaultHost = "0.0.0.0" + DefaultPort = 8080 + DefaultEmailAddr = "" + DefaultEmailUser = "" + DefaultEmailPass = "" + DefaultEmailHost = "" + DefaultEmailPort = 587 + DefaultSecret = "simpleauthlink-secret" + + HostFlag = "host" + PortFlag = "port" + EmailAddrFlag = "email-addr" + EmailUserFlag = "email-user" + EmailPassFlag = "email-pass" + EmailHostFlag = "email-host" + EmailPortFlag = "email-port" + SecretFlag = "secret" + HostFlagDesc = "service host" + PortFlagDesc = "service port" + EmailAddrFlagDesc = "email account address" + EmailUserFlagDesc = "email account username" + EmailPassFlagDesc = "email account password" + EmailHostFlagDesc = "email server host" + EmailPortFlagDesc = "email server port" + SecretFlagDesc = "secret used to generate the tokens" + + HostEnv = "HOST" + PortEnv = "PORT" + EmailAddrEnv = "EMAIL_ADDR" + EmailUserEnv = "EMAIL_USER" + EmailPassEnv = "EMAIL_PASS" + EmailHostEnv = "EMAIL_HOST" + EmailPortEnv = "EMAIL_PORT" + SecretEnv = "SECRET" +) diff --git a/cmd/demo/main.go b/cmd/demo/main.go new file mode 100644 index 0000000..ea4a4aa --- /dev/null +++ b/cmd/demo/main.go @@ -0,0 +1,63 @@ +package main + +import ( + "context" + "log" + + "github.com/simpleauthlink/authapi/api" + "github.com/simpleauthlink/authapi/cmd" + "github.com/simpleauthlink/authapi/internal/osflag" + "github.com/simpleauthlink/authapi/notification/email" +) + +func main() { + var ( + demoServer string + demoPort int + demoSecret string + ) + osflag.StringVar(&demoServer, cmd.HostEnv, cmd.HostFlag, cmd.DefaultHost, cmd.HostFlagDesc, false) + osflag.IntVar(&demoPort, cmd.PortEnv, cmd.PortFlag, cmd.DefaultPort, cmd.PortFlagDesc, false) + osflag.StringVar(&demoSecret, cmd.SecretEnv, cmd.SecretFlag, cmd.DefaultSecret, cmd.SecretFlagDesc, false) + if err := osflag.Parse(nil); err != nil { + log.Fatalln("ERR: error parsing flags:", err) + } + log.Println("INF: starting service with config:", demoServer, demoPort, demoSecret) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // create the email queue + emailQueue, err := email.NewEmailQueue(context.Background(), &email.EmailConfig{ + FromName: "SimpleAuthLink Demo", + FromAddress: "demo@simpleauth.link", + SMTPServer: demoServer, + SMTPPort: 2525, + }) + if err != nil { + log.Fatalln("WRN: something occurs during email queue creation:", err) + } + // start the email queue and defer to stop it + emailQueue.Start() + defer emailQueue.Stop() + // create the service + service, err := api.New(ctx, &api.Config{ + Server: demoServer, + ServerPort: demoPort, + Secret: demoSecret, + DemoMode: true, + DemoSMTPAddr: demoServer, + DemoSMTPPort: 2525, + }, emailQueue) + if err != nil { + log.Fatalln("ERR: error creating service:", err) + } + // start the service in background + go func() { + if err := service.Start(); err != nil { + log.Fatalln("ERR: error running service:", err) + } + }() + // wait for the service to finish + if err := service.WaitToShutdown(); err != nil { + log.Fatalln("ERR: error waiting for service to finish:", err) + } +} diff --git a/db/db.go b/db/db.go deleted file mode 100644 index 9f8cec9..0000000 --- a/db/db.go +++ /dev/null @@ -1,111 +0,0 @@ -package db - -import ( - "fmt" - "time" -) - -var ( - // ErrInvalidConfig error is returned when the provided database - // configuration is missing or invalid. - ErrInvalidConfig = fmt.Errorf("invalid database config") - // ErrOpenConn error is returned when the database connection can't be - // opened with the provided configuration. - ErrOpenConn = fmt.Errorf("error opening database") - // ErrCloseConn error is returned when the database connection can't be - // closed. - ErrCloseConn = fmt.Errorf("error closing database") - // ErrAppNotFound error is returned when the desired app is not found in the - // database. - ErrAppNotFound = fmt.Errorf("app not found") - // ErrGetApp error is returned when something fails getting a app from the - // database. - ErrGetApp = fmt.Errorf("error getting the app from database") - // ErrSetApp error is returned when something fails storing a app in the - // database. - ErrSetApp = fmt.Errorf("error storing the app in database") - // ErrDelApp error is returned when something fails deleting a app from the - // database. - ErrDelApp = fmt.Errorf("error deleting the app from database") - // ErrSecretNotFound error is returned when the desired secret is not found - // in the database. - ErrSetSecret = fmt.Errorf("error storing the secret in database") - // ErrDelSecret error is returned when something fails deleting a secret - // from the database. - ErrDelSecret = fmt.Errorf("error deleting the secret from database") - // ErrTokenNotFound error is returned when the desired token is not found in - // the database. - ErrTokenNotFound = fmt.Errorf("token not found") - // ErrGetToken error is returned when something fails getting a token from - // the database. - ErrGetToken = fmt.Errorf("error getting the token from database") - // ErrSetToken error is returned when something fails storing a token in the - // database. - ErrSetToken = fmt.Errorf("error storing the token in database") - // ErrDelToken error is returned when something fails deleting a token from - // the database. - ErrDelToken = fmt.Errorf("error deleting the token from database") -) - -// App struct represents the application information that is stored in the -// database. -type App struct { - Name string - AdminEmail string - SessionDuration uint64 - RedirectURL string - UsersQuota int64 -} - -// Token type represents the token that is stored in the database. -type Token string - -type DB interface { - // Init method allows to the interface implementation to receive some config - // information and init the database connection. It returns an error if the - // config is invalid or the connection can't be opened. - Init(config any) error - // Close method allows to the interface implementation to close the database - // connection. It returns an error if something fails during the closing. - Close() error - // AppById method gets an app from the database based on the app id. It - // returns the app and an error if something goes wrong. - AppById(appId string) (*App, error) - // AppBySecret method gets an app from the database based on the app secret. - // It returns the app, the app id and an error if something goes wrong. - AppBySecret(secret string) (*App, string, error) - // SetApp method stores an app in the database. It returns an error if - // something goes wrong. - SetApp(appId string, app *App) error - // DeleteApp method deletes an app from the database. It returns an error if - // something goes wrong. - DeleteApp(appId string) error - // ValidSecret method checks if a secret is valid. It returns true if the - // secret is valid and false if it is not. - ValidSecret(secret, appId string) (bool, error) - // SetSecret method stores a secret in the database. It returns an error if - // something goes wrong. - SetSecret(secret, appId string) error - // DeleteSecret method deletes a secret from the database. It returns an - // error if something goes wrong. - DeleteSecret(secret string) error - // TokenExpiration method gets the token expiration from the database. It - // returns the expiration time and an error if something goes wrong. - TokenExpiration(token Token) (time.Time, error) - // SetToken method stores a token in the database with an expiration time. - // It returns an error if something goes wrong. - SetToken(token Token, expiration time.Time) error - // DeleteToken method deletes a token from the database. It returns an error - // if something goes wrong. - DeleteToken(token Token) error - // DeleteTokenByPrefix method deletes all the tokens with the provided - // prefix from the database. It returns an error if something goes wrong. - DeleteTokensByPrefix(prefix string) error - // DeleteExpiredTokens method deletes all the expired tokens from the - // database. It returns an error if something goes wrong. - DeleteExpiredTokens() error - // CountTokens method counts the number of tokens in the database. It allows - // to filter the tokens by the provided prefix. It returns the number of - // tokens and an error if something goes wrong. - CountTokens(prefix string) (int64, error) -} diff --git a/db/mongo/apps.go b/db/mongo/apps.go deleted file mode 100644 index f29ce67..0000000 --- a/db/mongo/apps.go +++ /dev/null @@ -1,149 +0,0 @@ -package mongo - -import ( - "context" - "errors" - "time" - - "github.com/simpleauthlink/authapi/db" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" -) - -type App struct { - ID string `bson:"_id"` - Name string `bson:"name"` - AdminEmail string `bson:"admin_email"` - SessionDuration uint64 `bson:"session_duration"` - RedirectURL string `bson:"redirect_url"` - UsersQuota int64 `bson:"users_quota"` - Secret string `bson:"secret"` -} - -func (md *MongoDriver) AppById(appId string) (*db.App, error) { - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - // get app from the database based on the app id - var app App - if err := md.apps.FindOne(ctx, bson.M{"_id": appId}).Decode(&app); err != nil { - if err == mongo.ErrNoDocuments { - return nil, db.ErrAppNotFound - } - return nil, errors.Join(db.ErrGetApp, err) - } - // return app - return &db.App{ - Name: app.Name, - AdminEmail: app.AdminEmail, - SessionDuration: app.SessionDuration, - RedirectURL: app.RedirectURL, - UsersQuota: app.UsersQuota, - }, nil -} - -func (md *MongoDriver) AppBySecret(secret string) (*db.App, string, error) { - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - // get app from the database based on the app secret - var app App - if err := md.apps.FindOne(ctx, bson.M{"secret": secret}).Decode(&app); err != nil { - if err == mongo.ErrNoDocuments { - return nil, "", db.ErrAppNotFound - } - return nil, "", errors.Join(db.ErrGetApp, err) - } - // return app and app id - return &db.App{ - Name: app.Name, - AdminEmail: app.AdminEmail, - SessionDuration: app.SessionDuration, - RedirectURL: app.RedirectURL, - UsersQuota: app.UsersQuota, - }, app.ID, nil -} - -func (md *MongoDriver) SetApp(appId string, app *db.App) error { - md.keysLock.Lock() - defer md.keysLock.Unlock() - // create or update app in the database - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - dbApp, err := dynamicUpdateDocument(App{ - ID: appId, - Name: app.Name, - AdminEmail: app.AdminEmail, - SessionDuration: app.SessionDuration, - RedirectURL: app.RedirectURL, - UsersQuota: app.UsersQuota, - }, nil) - if err != nil { - return errors.Join(db.ErrSetApp, err) - } - opts := options.Update().SetUpsert(true) - if _, err := md.apps.UpdateOne(ctx, bson.M{"_id": appId}, dbApp, opts); err != nil { - return errors.Join(db.ErrSetApp, err) - } - return nil -} - -func (md *MongoDriver) DeleteApp(appId string) error { - md.keysLock.Lock() - defer md.keysLock.Unlock() - // delete secret from the database by the app id - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - if _, err := md.apps.DeleteOne(ctx, bson.M{"_id": appId}); err != nil { - if err == mongo.ErrNoDocuments { - return db.ErrAppNotFound - } - return errors.Join(db.ErrDelApp, err) - } - return nil -} - -func (md *MongoDriver) ValidSecret(secret, appId string) (bool, error) { - md.keysLock.Lock() - defer md.keysLock.Unlock() - // get app from the database based on the app id - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - var app App - if err := md.apps.FindOne(ctx, bson.M{"_id": appId}).Decode(&app); err != nil { - if err == mongo.ErrNoDocuments { - return false, db.ErrAppNotFound - } - return false, errors.Join(db.ErrGetApp, err) - } - return app.Secret == secret, nil -} - -func (md *MongoDriver) SetSecret(secret, appId string) error { - md.keysLock.Lock() - defer md.keysLock.Unlock() - // set secret to app in the database by the app id - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - if _, err := md.apps.UpdateOne(ctx, bson.M{"_id": appId}, bson.M{"$set": bson.M{"secret": secret}}); err != nil { - if err == mongo.ErrNoDocuments { - return db.ErrAppNotFound - } - return errors.Join(db.ErrSetSecret, err) - } - return nil -} - -func (md *MongoDriver) DeleteSecret(secret string) error { - md.keysLock.Lock() - defer md.keysLock.Unlock() - // delete secret of the app from the database - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - if _, err := md.apps.UpdateOne(ctx, bson.M{"secret": secret}, bson.M{"$unset": bson.M{"secret": ""}}); err != nil { - if err == mongo.ErrNoDocuments { - return db.ErrAppNotFound - } - return errors.Join(db.ErrDelSecret, err) - } - return nil -} diff --git a/db/mongo/mongo.go b/db/mongo/mongo.go deleted file mode 100644 index 2e8c183..0000000 --- a/db/mongo/mongo.go +++ /dev/null @@ -1,161 +0,0 @@ -package mongo - -import ( - "context" - "errors" - "fmt" - "reflect" - "sync" - "time" - - "github.com/simpleauthlink/authapi/db" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" - "go.mongodb.org/mongo-driver/mongo/readpref" -) - -const ( - tokensCollection = "tokens" - secretsCollection = "secrets" - appsCollection = "apps" -) - -type Config struct { - MongoURI string - Database string -} - -type MongoDriver struct { - ctx context.Context - cancel context.CancelFunc - config Config - client *mongo.Client - keysLock sync.RWMutex - - tokens *mongo.Collection - apps *mongo.Collection -} - -func (md *MongoDriver) Init(config any) error { - // validate config - cfg, ok := config.(Config) - if !ok { - return db.ErrInvalidConfig - } - if cfg.Database == "" { - return fmt.Errorf("%w: no database name provided", db.ErrInvalidConfig) - } - if cfg.MongoURI == "" { - return fmt.Errorf("%w: no database url provided", db.ErrInvalidConfig) - } - // init the client options - opts := options.Client() - opts.ApplyURI(cfg.MongoURI) - opts.SetMaxConnecting(200) - timeout := time.Second * 10 - opts.ConnectTimeout = &timeout - // connect to the database - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - client, err := mongo.Connect(ctx, opts) - if err != nil { - return errors.Join(db.ErrOpenConn, err) - } - // check if the connection is available - ctx, cancel2 := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel2() - if err := client.Ping(ctx, readpref.Primary()); err != nil { - return errors.Join(db.ErrOpenConn, err) - } - // create the internal context - md.ctx, md.cancel = context.WithCancel(context.Background()) - // set the client and config - md.client = client - md.config = cfg - // instantiate the collections - md.tokens = client.Database(cfg.Database).Collection(tokensCollection) - md.apps = client.Database(cfg.Database).Collection(appsCollection) - // create the indexes - if err := md.createIndexes(); err != nil { - return errors.Join(db.ErrOpenConn, err) - } - return nil -} - -func (md *MongoDriver) Close() error { - md.cancel() - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if err := md.client.Disconnect(ctx); err != nil { - return errors.Join(db.ErrCloseConn, err) - } - return nil -} - -// createIndexes creates the indexes for the collections. It creates an index -// for the app secrets and an index for the token expiration. It returns an -// error if something goes wrong. -func (md *MongoDriver) createIndexes() error { - ctx, cancel := context.WithTimeout(md.ctx, 20*time.Second) - defer cancel() - // create an index for app secrets - if _, err := md.apps.Indexes().CreateOne(ctx, mongo.IndexModel{ - Keys: bson.D{{Key: "secrets", Value: 1}}, // 1 for ascending order - Options: nil, - }); err != nil { - return err - } - // create an index for token expiration - if _, err := md.tokens.Indexes().CreateOne(ctx, mongo.IndexModel{ - Keys: bson.D{{Key: "expiration", Value: 1}}, - Options: nil, - }); err != nil { - return err - } - return nil -} - -// dynamicUpdateDocument creates a BSON update document from a struct, -// including only non-zero fields. It uses reflection to iterate over the -// struct fields and create the update document. The struct fields must have -// a bson tag to be included in the update document. The _id field is skipped. -func dynamicUpdateDocument(item interface{}, alwaysUpdate []string) (bson.M, error) { - // check if the input is a pointer to a struct - val := reflect.ValueOf(item) - if val.Kind() == reflect.Ptr { - val = val.Elem() - } - // check if the input is a struct - if !val.IsValid() || val.Kind() != reflect.Struct { - return nil, fmt.Errorf("input must be a valid struct") - } - update := bson.M{} - typ := val.Type() - // create a map for quick lookup of always update fields - alwaysUpdateMap := make(map[string]bool, len(alwaysUpdate)) - for _, tag := range alwaysUpdate { - alwaysUpdateMap[tag] = true - } - // iterate over the struct fields - for i := 0; i < val.NumField(); i++ { - // check if the field can be accessed - field := val.Field(i) - if !field.CanInterface() { - continue - } - // get the field bson tag and type - fieldType := typ.Field(i) - tag := fieldType.Tag.Get("bson") - // skip the field if the tag is empty, "-" or "_id" - if tag == "" || tag == "-" || tag == "_id" { - continue - } - // check if the field should always be updated or is not the zero value - _, alwaysUpdate := alwaysUpdateMap[tag] - if alwaysUpdate || !reflect.DeepEqual(field.Interface(), reflect.Zero(field.Type()).Interface()) { - update[tag] = field.Interface() - } - } - return bson.M{"$set": update}, nil -} diff --git a/db/mongo/tokens.go b/db/mongo/tokens.go deleted file mode 100644 index 1afe378..0000000 --- a/db/mongo/tokens.go +++ /dev/null @@ -1,113 +0,0 @@ -package mongo - -import ( - "context" - "errors" - "time" - - "github.com/simpleauthlink/authapi/db" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" -) - -type Token struct { - Token db.Token `bson:"_id"` - Expiration int64 `bson:"expiration"` -} - -func (md *MongoDriver) TokenExpiration(token db.Token) (time.Time, error) { - var dbToken Token - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - if err := md.tokens.FindOne(ctx, bson.M{"_id": token}).Decode(&dbToken); err != nil { - if err == mongo.ErrNoDocuments { - return time.Time{}, db.ErrTokenNotFound - } - return time.Time{}, errors.Join(db.ErrGetToken, err) - } - return time.Unix(0, dbToken.Expiration), nil -} - -func (md *MongoDriver) SetToken(token db.Token, expiration time.Time) error { - md.keysLock.Lock() - defer md.keysLock.Unlock() - // set token in the database - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - dbToken := Token{ - Token: token, - Expiration: expiration.UnixNano(), - } - opts := options.Replace().SetUpsert(true) - if _, err := md.tokens.ReplaceOne(ctx, bson.M{"_id": token}, dbToken, opts); err != nil { - return errors.Join(db.ErrSetToken, err) - } - return nil -} - -func (md *MongoDriver) DeleteToken(token db.Token) error { - md.keysLock.Lock() - defer md.keysLock.Unlock() - // delete token from the database - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - if _, err := md.tokens.DeleteOne(ctx, bson.M{"_id": token}); err != nil { - if err == mongo.ErrNoDocuments { - return db.ErrTokenNotFound - } - return errors.Join(db.ErrDelToken, err) - } - return nil -} - -func (md *MongoDriver) DeleteTokensByPrefix(prefix string) error { - // check if the prefix is empty and return nil if it is - if prefix == "" { - return nil - } - // check if there is a token with the provided prefix in the database - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - if _, err := md.tokens.DeleteMany(ctx, bson.M{"_id": bson.M{"$regex": "^" + prefix}}); err != nil { - if err == mongo.ErrNoDocuments { - return db.ErrTokenNotFound - } - return errors.Join(db.ErrGetToken, err) - } - return nil -} - -func (md *MongoDriver) DeleteExpiredTokens() error { - md.keysLock.Lock() - defer md.keysLock.Unlock() - // delete expired tokens from the database, filter by expiration time less - // than now - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - dbNow := time.Now().UnixNano() - if _, err := md.tokens.DeleteMany(ctx, bson.M{"expiration": bson.M{"$lt": dbNow}}); err != nil { - return errors.Join(db.ErrDelToken, err) - } - return nil -} - -func (md *MongoDriver) CountTokens(prefix string) (int64, error) { - // count the number of tokens in the database, filter by the provided prefix - ctx, cancel := context.WithTimeout(md.ctx, 5*time.Second) - defer cancel() - // filter by prefix if provided - filter := bson.M{} - if prefix != "" { - filter = bson.M{"_id": bson.M{"$regex": "^" + prefix}} - } - // count the number of tokens and return the result - count, err := md.tokens.CountDocuments(ctx, filter) - if err != nil { - if err == mongo.ErrNoDocuments { - return 0, db.ErrTokenNotFound - } - return 0, errors.Join(db.ErrGetToken, err) - } - return count, nil -} diff --git a/db/temp.go b/db/temp.go deleted file mode 100644 index 760b4a8..0000000 --- a/db/temp.go +++ /dev/null @@ -1,152 +0,0 @@ -package db - -import ( - "strings" - "sync" - "time" -) - -type TempDriver struct { - apps map[string]App - secretToApp map[string]string - tokens map[Token]int64 - lock sync.RWMutex -} - -func (tdb *TempDriver) Init(_ any) error { - tdb.apps = make(map[string]App) - tdb.secretToApp = make(map[string]string) - tdb.tokens = make(map[Token]int64) - return nil -} - -func (tdb *TempDriver) Close() error { - return nil -} - -func (tdb *TempDriver) AppById(appId string) (*App, error) { - tdb.lock.RLock() - defer tdb.lock.RUnlock() - app, ok := tdb.apps[appId] - if !ok { - return nil, ErrAppNotFound - } - return &app, nil -} - -func (tdb *TempDriver) AppBySecret(secret string) (*App, string, error) { - tdb.lock.RLock() - defer tdb.lock.RUnlock() - appId, ok := tdb.secretToApp[secret] - if !ok { - return nil, "", ErrAppNotFound - } - app, ok := tdb.apps[appId] - if !ok { - return nil, "", ErrAppNotFound - } - return &app, appId, nil -} - -func (tdb *TempDriver) SetApp(appId string, app *App) error { - tdb.lock.Lock() - defer tdb.lock.Unlock() - tdb.apps[appId] = *app - return nil -} - -func (tdb *TempDriver) DeleteApp(appId string) error { - tdb.lock.Lock() - defer tdb.lock.Unlock() - delete(tdb.apps, appId) - return nil -} - -func (tdb *TempDriver) ValidSecret(secret, appId string) (bool, error) { - tdb.lock.RLock() - defer tdb.lock.RUnlock() - appIdFromSecret, ok := tdb.secretToApp[secret] - if !ok { - return false, nil - } - return appIdFromSecret == appId, nil -} - -func (tdb *TempDriver) SetSecret(secret, appId string) error { - tdb.lock.Lock() - defer tdb.lock.Unlock() - tdb.secretToApp[secret] = appId - return nil -} - -func (tdb *TempDriver) DeleteSecret(secret string) error { - tdb.lock.Lock() - defer tdb.lock.Unlock() - delete(tdb.secretToApp, secret) - return nil -} - -func (tdb *TempDriver) TokenExpiration(token Token) (time.Time, error) { - tdb.lock.RLock() - defer tdb.lock.RUnlock() - exp, ok := tdb.tokens[token] - if !ok { - return time.Time{}, ErrTokenNotFound - } - return time.Unix(0, exp), nil -} - -func (tdb *TempDriver) SetToken(token Token, expiration time.Time) error { - tdb.lock.Lock() - defer tdb.lock.Unlock() - tdb.tokens[token] = expiration.UnixNano() - return nil -} - -func (tdb *TempDriver) DeleteToken(token Token) error { - tdb.lock.Lock() - defer tdb.lock.Unlock() - delete(tdb.tokens, token) - return nil -} - -func (tdb *TempDriver) DeleteTokensByPrefix(prefix string) error { - tdb.lock.Lock() - defer tdb.lock.Unlock() - if prefix == "" { - return nil - } - for token := range tdb.tokens { - if strings.HasPrefix(string(token), prefix) { - delete(tdb.tokens, token) - } - } - return nil -} - -func (tdb *TempDriver) DeleteExpiredTokens() error { - tdb.lock.Lock() - defer tdb.lock.Unlock() - now := time.Now().UnixNano() - for token, expiration := range tdb.tokens { - if now > expiration { - delete(tdb.tokens, token) - } - } - return nil -} - -func (tdb *TempDriver) CountTokens(prefix string) (int64, error) { - tdb.lock.RLock() - defer tdb.lock.RUnlock() - if prefix == "" { - return int64(len(tdb.tokens)), nil - } - var count int64 - for token := range tdb.tokens { - if strings.HasPrefix(string(token), prefix) { - count++ - } - } - return count, nil -} diff --git a/docker-compose.yml b/docker-compose.yml deleted file mode 100644 index d85f3c9..0000000 --- a/docker-compose.yml +++ /dev/null @@ -1,39 +0,0 @@ -name: simpleauthlink - -services: - authapi: - env_file: - - .env - build: - context: ./ - ports: - - ${SIMPLEAUTH_PORT}:${SIMPLEAUTH_PORT} - sysctls: - net.core.somaxconn: 8128 - restart: ${RESTART:-unless-stopped} - depends_on: - - mongo - mongo: - image: mongo - restart: ${RESTART:-unless-stopped} - ports: - - 27017:27017 - environment: - - MONGO_INITDB_ROOT_USERNAME=root - - MONGO_INITDB_ROOT_PASSWORD=authapi - - MONGO_INITDB_DATABASE=simpleauth - volumes: - - mongodb:/data/mongodb - mongo-express: - image: mongo-express - restart: ${RESTART:-unless-stopped} - ports: - - 8081:8081 - environment: - ME_CONFIG_MONGODB_ADMINUSERNAME: root - ME_CONFIG_MONGODB_ADMINPASSWORD: authapi - ME_CONFIG_MONGODB_URL: mongodb://root:authapi@mongo:27017/ -volumes: - mongodb: {} - - diff --git a/docker/Dockerfile.demo b/docker/Dockerfile.demo new file mode 100644 index 0000000..38694fb --- /dev/null +++ b/docker/Dockerfile.demo @@ -0,0 +1,18 @@ +# build +FROM golang:1.24-alpine as builder + +WORKDIR /app/data +COPY . . + +RUN go mod tidy +RUN go build -o /authapi ./cmd/demo/main.go + +# deploy +FROM alpine:latest + +EXPOSE 80 + +WORKDIR / +COPY --from=builder /authapi /authapi + +ENTRYPOINT /authapi \ No newline at end of file diff --git a/Dockerfile b/docker/Dockerfile.prod similarity index 71% rename from Dockerfile rename to docker/Dockerfile.prod index 960ed89..dc43de2 100644 --- a/Dockerfile +++ b/docker/Dockerfile.prod @@ -1,5 +1,5 @@ # build -FROM golang:1.21-alpine as builder +FROM golang:1.24-alpine as builder WORKDIR /app/data COPY . . @@ -10,8 +10,9 @@ RUN go build -o /authapi ./cmd/authapi/main.go # deploy FROM alpine:latest +EXPOSE 80 + WORKDIR / COPY --from=builder /authapi /authapi -COPY --from=builder /app/data/assets /assets ENTRYPOINT /authapi \ No newline at end of file diff --git a/email/disposable.go b/email/disposable.go deleted file mode 100644 index ab62b07..0000000 --- a/email/disposable.go +++ /dev/null @@ -1,68 +0,0 @@ -package email - -import ( - "bufio" - "context" - "errors" - "net/http" - "regexp" - "strings" - "time" -) - -// domainRgx is the regular expression used to validate a domain. -var domainRgx = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`) - -// LoadRemoteDisposableDomains loads a list of disposable domains from a remote -// source url. It reads the content of the source url line by line and parses -// each line as a domain. It returns a list of disposable domains or an error if -// something fails. -func LoadRemoteDisposableDomains(ctx context.Context, disposableSrc string) ([]string, error) { - internalCtx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - // prepare the request - req, err := http.NewRequestWithContext(internalCtx, http.MethodGet, disposableSrc, nil) - if err != nil { - return nil, errors.Join(ErrLoadingDisposableDomains, err) - } - // perform the request - resp, err := http.DefaultClient.Do(req) - if err != nil { - return nil, errors.Join(ErrLoadingDisposableDomains, err) - } - // read the response body line by line - defer resp.Body.Close() - scanner := bufio.NewScanner(resp.Body) - var domains []string - for scanner.Scan() { - domain := scanner.Text() - if domainRgx.MatchString(domain) { - domains = append(domains, domain) - } - } - if err := scanner.Err(); err != nil { - return nil, errors.Join(ErrLoadingDisposableDomains, err) - } - return domains, nil -} - -// CheckEmail checks if the email address is valid. It compares the domain with -// a list of disallowed domains. It returns true if the email address is valid, -// otherwise it returns false. -func CheckEmail(disallowedDomains []string, email string) bool { - if len(disallowedDomains) == 0 { - return true - } - // split the email address - parts := strings.Split(email, "@") - if len(parts) != 2 { - return false - } - // check the domain - for _, domain := range disallowedDomains { - if domain == parts[1] { - return false - } - } - return true -} diff --git a/email/emailqueue.go b/email/emailqueue.go deleted file mode 100644 index 9dac8ec..0000000 --- a/email/emailqueue.go +++ /dev/null @@ -1,224 +0,0 @@ -package email - -import ( - "bytes" - "context" - "fmt" - "net/mail" - "net/smtp" - "net/textproto" - "regexp" - "sync" - "time" -) - -// sendRetries is the number of retries to send the email. -const sendRetries = 3 - -// emailRgx is the regular expression used to validate an email address. -var emailRgx = regexp.MustCompile(`^[\w-\.]+@([\w-]+\.)+[\w-]{2,}$`) - -// EmailConfig struct represents the email configuration that is needed to send -// an email using and SMTP server. It includes the email address (used as the -// sender address but also as the username for the SMTP server), the email -// server hostname, its port and the password. -type EmailConfig struct { - Address string - EmailHost string - EmailPort int - Password string - DisposableSrc string - TokenEmailTemplate string - AppEmailTemplate string -} - -// Email struct represents the email that is going to be sent. It includes the -// recipient email address, the subject and the body of the email. -type Email struct { - To string - Subject string - Body string -} - -// EmailQueue struct represents the email queue. It includes the context and the -// cancel function to stop the queue, the configuration of the server to send -// the email, the list of emails to send, and the waiter to wait for the -// background process to finish. -type EmailQueue struct { - ctx context.Context - cancel context.CancelFunc - cfg *EmailConfig - items []*Email - itemsMtx sync.Mutex - waiter sync.WaitGroup - disallowedDomains []string -} - -// NewEmailQueue creates a new EmailQueue with the provided configuration. -func NewEmailQueue(ctx context.Context, cfg *EmailConfig) (*EmailQueue, error) { - // check if the configuration is valid - if cfg.Address == "" || !emailRgx.MatchString(cfg.Address) || - cfg.EmailHost == "" || cfg.EmailPort == 0 || cfg.Password == "" { - return nil, ErrInvalidConfig - } - internalCtx, cancel := context.WithCancel(ctx) - // load the disposable domains if a source is provided - var err error - disallowedDomains := []string{} - if cfg.DisposableSrc != "" { - disallowedDomains, err = LoadRemoteDisposableDomains(internalCtx, cfg.DisposableSrc) - } - // return the email queue - return &EmailQueue{ - ctx: internalCtx, - cancel: cancel, - cfg: cfg, - items: []*Email{}, - disallowedDomains: disallowedDomains, - }, err -} - -// Start method starts the email queue. It listens for new emails in the queue -// and sends them using the provided configuration. -func (eq *EmailQueue) Start() { - eq.waiter.Add(1) - go func() { - defer eq.waiter.Done() - for { - select { - case <-eq.ctx.Done(): - return - default: - e := eq.Pop() - if e == nil { - continue - } - if err := eq.Send(e); err != nil { - fmt.Println(err) - } else { - eq.Pop() - } - } - time.Sleep(time.Second) - } - }() -} - -func (eq *EmailQueue) Stop() { - eq.cancel() - eq.waiter.Wait() -} - -// Push method adds a new email to the queue. -func (eq *EmailQueue) Push(e *Email) error { - // check if the email is valid - if e.To == "" || !emailRgx.MatchString(e.To) || e.Subject == "" || e.Body == "" { - return ErrInvalidEmail - } - // check if the email is allowed - if !eq.Allowed(e.To) { - return ErrDisallowedDomain - } - eq.itemsMtx.Lock() - eq.items = append(eq.items, e) - eq.itemsMtx.Unlock() - return nil -} - -// Top method returns the first email in the queue. -func (eq *EmailQueue) Top() *Email { - eq.itemsMtx.Lock() - defer eq.itemsMtx.Unlock() - if len(eq.items) == 0 { - return nil - } - return eq.items[0] -} - -// Pop method removes the first email in the queue and returns it. -func (eq *EmailQueue) Pop() *Email { - eq.itemsMtx.Lock() - defer eq.itemsMtx.Unlock() - if len(eq.items) == 0 { - return nil - } - e := eq.items[0] - eq.items = eq.items[1:] - return e -} - -// Send method sends the email using the queue configuration. It uses the -// email address as the sender address and the username for the SMTP server. -// It composes the email message, creates the auth object with the email -// credentials, the server string with the host and the port, and the receipts. -// Finally, it sends the email. If something fails during the process, it -// returns an error. -func (eq *EmailQueue) Send(e *Email) error { - // compose the email body - body, err := eq.encodeEmail(e) - if err != nil { - return fmt.Errorf("error composing email: %w", err) - } - // check if the email is allowed - if !eq.Allowed(e.To) { - return ErrDisallowedDomain - } - // create the auth object with the email credentials - auth := smtp.PlainAuth("", eq.cfg.Address, eq.cfg.Password, eq.cfg.EmailHost) - // create the server string with the host and the port and the receipts - server := fmt.Sprintf("%s:%d", eq.cfg.EmailHost, eq.cfg.EmailPort) - receipts := []string{e.To} - // send the email - for i := 0; i < sendRetries; i++ { - if err = smtp.SendMail(server, auth, eq.cfg.Address, receipts, body); err == nil { - break - } - } - if err != nil { - return fmt.Errorf("error sending email: %w", err) - } - return nil -} - -// Allowed method checks if the email address is allowed. It compares the domain -// with a list of disallowed domains. It returns true if the email address is -// allowed, otherwise it returns false. -func (eq *EmailQueue) Allowed(address string) bool { - if !emailRgx.MatchString(address) { - return false - } - return CheckEmail(eq.disallowedDomains, address) -} - -// encodeEmail method encodes the email to a byte slice. It validates the from -// and to addresses, sets the headers for the html email, and writes the body. -// It returns the encoded email or an error if something fails during the -// process. -func (eq *EmailQueue) encodeEmail(email *Email) ([]byte, error) { - // validate from address - from, err := mail.ParseAddress(eq.cfg.Address) - if err != nil { - return nil, fmt.Errorf("error parsing address: %w", err) - } - // validate to address - to, err := mail.ParseAddress(email.To) - if err != nil { - return nil, fmt.Errorf("error parsing address: %w", err) - } - // set headers for html email - header := textproto.MIMEHeader{} - header.Set(textproto.CanonicalMIMEHeaderKey("from"), from.Address) - header.Set(textproto.CanonicalMIMEHeaderKey("to"), to.Address) - header.Set(textproto.CanonicalMIMEHeaderKey("content-type"), "text/html; charset=UTF-8") - header.Set(textproto.CanonicalMIMEHeaderKey("mime-version"), "1.0") - header.Set(textproto.CanonicalMIMEHeaderKey("subject"), email.Subject) - // init empty message - var buffer bytes.Buffer - // write header - for key, value := range header { - buffer.WriteString(fmt.Sprintf("%s: %s\r\n", key, value[0])) - } - // write body - buffer.WriteString(fmt.Sprintf("\r\n%s", email.Body)) - return buffer.Bytes(), nil -} diff --git a/email/errors.go b/email/errors.go deleted file mode 100644 index 33218a7..0000000 --- a/email/errors.go +++ /dev/null @@ -1,19 +0,0 @@ -package email - -import "fmt" - -var ( - // ErrInvalidConfig is the error returned when the configuration is invalid. - ErrInvalidConfig = fmt.Errorf("invalid configuration") - // ErrInitQueue is the error returned when the queue cannot be initialized. - ErrInitQueue = fmt.Errorf("error initializing the queue") - // ErrInvalidDomain is the error returned when the domain is invalid. - ErrInvalidDomain = fmt.Errorf("invalid domain") - // ErrLoadingDisposableDomains is the error returned when the disposable - // domains cannot be loaded. - ErrLoadingDisposableDomains = fmt.Errorf("error loading disposable domains") - // ErrDisallowedDomain is the error returned when the domain is disallowed. - ErrDisallowedDomain = fmt.Errorf("disallowed domain") - // ErrInvalidEmail is the error returned when the email is invalid. - ErrInvalidEmail = fmt.Errorf("invalid email") -) diff --git a/email/templates.go b/email/templates.go deleted file mode 100644 index c2aa470..0000000 --- a/email/templates.go +++ /dev/null @@ -1,72 +0,0 @@ -package email - -import ( - "bytes" - "fmt" - "strings" - "text/template" -) - -// UserEmailData struct includes the data required to fill the user email -// template. -type UserEmailData struct { - AppName string - EmailHandler string - MagicLink string - Token string -} - -// AppEmailData struct includes the data required to fill the app email -// template. -type AppEmailData struct { - AppID string - AppName string - RedirectURL string - Secret string - EmailHandler string -} - -// NewUserEmailData creates a new UserEmailData with the provided data. -func NewUserEmailData(appName, email, magicLink, token string) *UserEmailData { - return &UserEmailData{ - AppName: appName, - EmailHandler: emailHandler(email), - MagicLink: magicLink, - Token: token, - } -} - -// NewAppEmailData creates a new AppEmailData with the provided data. -func NewAppEmailData(appID, appName, redirectURL, secret, email string) *AppEmailData { - return &AppEmailData{ - AppID: appID, - AppName: appName, - RedirectURL: redirectURL, - Secret: secret, - EmailHandler: emailHandler(email), - } -} - -// ParseTemplate parses the template file provided with the data provided. It -// returns the parsed template as a string. If an error occurs, it returns the -// error. -func ParseTemplate(templatePath string, data interface{}) (string, error) { - // parse the template file provided - t, err := template.ParseFiles(templatePath) - if err != nil { - return "", err - } - // execute the template to fill it with the data provided - buf := new(bytes.Buffer) - if err := t.Execute(buf, data); err != nil { - return "", fmt.Errorf("error parsing template: %w", err) - } - return buf.String(), nil -} - -// emailHandler method extracts the email handler from the email address. It -// splits the email address by the "@" symbol and returns the first part. -func emailHandler(emailAddress string) string { - emailParts := strings.Split(emailAddress, "@") - return emailParts[0] -} diff --git a/example.env b/example.env index b37a676..b2e07b8 100644 --- a/example.env +++ b/example.env @@ -1,6 +1,8 @@ -SIMPLEAUTH_EMAIL_ADDR="" -SIMPLEAUTH_EMAIL_PASS="" -SIMPLEAUTH_EMAIL_HOST="" -SIMPLEAUTH_DB_URI="mongodb://root:authapi@mongo:27017/" -SIMPLEAUTH_DB_NAME="simpleauth" -SIMPLEAUTH_DISPOSABLE_SRC="https://raw.githubusercontent.com/disposable-email-domains/disposable-email-domains/master/disposable_email_blocklist.conf" \ No newline at end of file +HOST="localhost" +PORT=8080 +EMAIL_ADDR="test@test.com" +EMAIL_USER="test@test.com" +EMAIL_PASS="smtp_server_password" +EMAIL_HOST="smtp.example.com" +EMAIL_PORT=587 +SECRET="my_backend_secret" \ No newline at end of file diff --git a/go.mod b/go.mod index 973a808..940068e 100644 --- a/go.mod +++ b/go.mod @@ -1,22 +1,7 @@ module github.com/simpleauthlink/authapi -go 1.21 +go 1.24 -require ( - github.com/lucasmenendez/apihandler v0.0.7 - go.mongodb.org/mongo-driver v1.15.0 -) +require github.com/lucasmenendez/apihandler v0.0.8 -require ( - github.com/golang/snappy v0.0.1 // indirect - github.com/klauspost/compress v1.13.6 // indirect - github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect - github.com/xdg-go/pbkdf2 v1.0.0 // indirect - github.com/xdg-go/scram v1.1.2 // indirect - github.com/xdg-go/stringprep v1.0.4 // indirect - github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect - golang.org/x/crypto v0.17.0 // indirect - golang.org/x/sync v0.1.0 // indirect - golang.org/x/text v0.14.0 // indirect - golang.org/x/time v0.6.0 // indirect -) +require golang.org/x/time v0.11.0 // indirect diff --git a/go.sum b/go.sum index afad0e8..7e7d3a9 100644 --- a/go.sum +++ b/go.sum @@ -1,64 +1,4 @@ -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= -github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM= -github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/klauspost/compress v1.13.6 h1:P76CopJELS0TiO2mebmnzgWaajssP/EszplttgQxcgc= -github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= -github.com/lucasmenendez/apihandler v0.0.4 h1:QspySW+hZp45HsLur2VcJQ/EcaRzll6XhOUriPRrHYs= -github.com/lucasmenendez/apihandler v0.0.4/go.mod h1:1R2dcf/Wbr6sx7Gjjv5oWWKgD8Pokib/d5BuCwhaBcA= -github.com/lucasmenendez/apihandler v0.0.5-0.20240520102504-ffd40a81622e h1:NpAmyHWwUwxtYGFiQRfHA1+BGnv05rQ9m/p9YK41yC0= -github.com/lucasmenendez/apihandler v0.0.5-0.20240520102504-ffd40a81622e/go.mod h1:gDwdzFu8GquIz0UkrA+UMjaYUQGtfDymm6i4iKEcM44= -github.com/lucasmenendez/apihandler v0.0.6 h1:og9FRFIiPwLAyLbwFS3IvwlewD6/woqlau+1PvISvRY= -github.com/lucasmenendez/apihandler v0.0.6/go.mod h1:gDwdzFu8GquIz0UkrA+UMjaYUQGtfDymm6i4iKEcM44= -github.com/lucasmenendez/apihandler v0.0.7 h1:OItUaGN5J+KrYFLZnQUNHXnOBP6HZyvlobyk1Jd7JkI= -github.com/lucasmenendez/apihandler v0.0.7/go.mod h1:gDwdzFu8GquIz0UkrA+UMjaYUQGtfDymm6i4iKEcM44= -github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe h1:iruDEfMl2E6fbMZ9s0scYfZQ84/6SPL6zC8ACM2oIL0= -github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= -github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= -github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= -github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= -github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= -github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= -github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= -github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d h1:splanxYIlg+5LfHAM6xpdFEAYOk8iySO56hMFq6uLyA= -github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.mongodb.org/mongo-driver v1.15.0 h1:rJCKC8eEliewXjZGf0ddURtl7tTVy1TK3bfl0gkUSLc= -go.mongodb.org/mongo-driver v1.15.0/go.mod h1:Vzb0Mk/pa7e6cWw85R4F/endUC3u0U9jGcNU603k65c= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= -golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= -golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= -golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= -golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +github.com/lucasmenendez/apihandler v0.0.8 h1:xHBNqdg+/eKpmjSvcQIkfIWxhsBcQa5TpwRzU00KugU= +github.com/lucasmenendez/apihandler v0.0.8/go.mod h1:u19tqauhQwxXbR2rw9//dCdM7oNNqzzPbByHh7R1imU= +golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0= +golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= diff --git a/helpers/consts.go b/helpers/consts.go deleted file mode 100644 index a3c1223..0000000 --- a/helpers/consts.go +++ /dev/null @@ -1,50 +0,0 @@ -package helpers - -const ( - // TokenSeparator constant is the separator used to split the token into - // parts. It is a string with a value of "-". - TokenSeparator = "-" - // TokenQueryParam constant is the query parameter used to send the token in - // the request. It is a string with a value of "token". - TokenQueryParam = "token" - // AppSecretHeader constant is the header used to send the app secret in the - // request. It is a string with a value of "APP_SECRET". - AppSecretHeader = "APP_SECRET" - // DefaultAPIEndpoint constant is the default API endpoint used by the - // client. It is a string with a value of "https://api.simpleauth.link/". - DefaultAPIEndpoint = "https://api.simpleauth.link/" - // HealthCheckPath constant is the path used to check the health of the API - // server. It is a string with a value of "/health". - HealthCheckPath = "/health" - // AppEndpointPath constant is the path used to API endpoints related to - // apps. It is a string with a value of "/app". - AppEndpointPath = "/app" - // UserEndpointPath constant is the path used to API endpoints related to - // users. It is a string with a value of "/user". - UserEndpointPath = "/user" - // MinTokenDuration constant is the minimum duration allowed for a token to - // be valid, which is an integer with a value of 60 (seconds). - MinTokenDuration = 60 // seconds - // defaultUsersQuota constant is the default number of users allowed for an - // app, which is an integer with a value of 100. - DefaultUsersQuota = 100 // users - // UserIdSize constant is the size of the user id, which is an integer with a - // value of 4 (bytes). - UserIdSize = 4 - // AppIdSize constant is the size of the app id, which is an integer with a - // value of 8 (bytes). - AppIdSize = 8 - // EmailHashSize constant is the size of the email hash, which is an integer - // with a value of 4 (bytes). The email hash is used to generate the user id - // and the app id. - EmailHashSize = 4 - // AppNonceSize constant is the size of the app nonce, which is an integer - // with a value of 4 (bytes). The app nonce is used to generate the app id. - AppNonceSize = 4 - // SecretSize constant is the size of the secret, which is an integer with a - // value of 16 (bytes). - SecretSize = 16 - // TokenSize constant is the size of the token, which is an integer with a - // value of 8 (bytes). - TokenSize = 8 -) diff --git a/helpers/helpers.go b/helpers/helpers.go deleted file mode 100644 index ac0d67e..0000000 --- a/helpers/helpers.go +++ /dev/null @@ -1,110 +0,0 @@ -package helpers - -import ( - "crypto/sha256" - "encoding/hex" - "fmt" - "math/rand" - "net/url" - "strings" -) - -// EncodeUserToken function encodes the user information into a token and -// returns it. It receives the app id and the email of the user and returns the -// token and the user id. If the app id or the email are empty, it returns an -// error. The token is composed of three parts separated by a token separator. -// The first part is a random sequence of 8 bytes encoded as a hexadecimal -// string. The second part is the app id and the third part is the user id. The -// user id is generated hashing the email with a length of 4 bytes. The token -// is returned following the token format: -// -// [appId(8)]-[userId(8)]-[randomPart(16)] -func EncodeUserToken(appId, email string) (string, string, error) { - // check if the app id and email are not empty - if len(appId) == 0 || len(email) == 0 { - return "", "", fmt.Errorf("appId and email are required") - } - bToken := RandBytes(TokenSize) - hexToken := hex.EncodeToString(bToken) - // hash email - userId, err := Hash(email, UserIdSize) - if err != nil { - return "", "", err - } - return strings.Join([]string{appId, userId, hexToken}, TokenSeparator), userId, nil -} - -// DecodeUserToken function decodes the user information from the token provided -// and returns the app id and the user id. If the token is invalid, it returns -// an error. It splits the provided token by the token separator and returns the -// second and third parts, which are the app id and the user id respectively, -// following the token format: -// -// [appId(8)]-[userId(8)]-[randomPart(16)] -func DecodeUserToken(token string) (string, string, error) { - tokenParts := strings.Split(token, TokenSeparator) - if len(tokenParts) != 3 { - return "", "", fmt.Errorf("invalid token") - } - return tokenParts[0], tokenParts[1], nil -} - -// RandBytes generates a random byte slice of length n. It returns nil if n is -// less than 1. -func RandBytes(n int) []byte { - if n < 1 { - return nil - } - b := make([]byte, n) - for i := 0; i < n; { - val := rand.Uint64() - for j := 0; j < 8 && i < n; j++ { - b[i] = byte(val & 0xff) - val >>= 8 - i++ - } - } - return b -} - -// Hash generates a hash of the input string using SHA-256 algorithm. The n -// parameter allows to truncate the hash to n bytes. It returns the hash as a -// hexadecimal string. The resulting string will have a length of 2*n. If n is -// less than 1 or greater than the hash length, the full hash will be returned. -// If the input string is empty, it returns an empty string. If something fails -// during the hashing process, it returns an error. -func Hash(input string, n int) (string, error) { - if input == "" { - return "", nil - } - hash := sha256.New() - if _, err := hash.Write([]byte(input)); err != nil { - return "", err - } - bHash := hash.Sum(nil) - if n > 0 && n < len(bHash) { - bHash = bHash[:n] - } - return hex.EncodeToString(bHash), nil -} - -// SafeURL function returns a safe URL string from the provided URL. It returns -// an empty string if the URL is nil. The resulting string will have the format: -// scheme://host/path#fragment?query. If the URL has no path, query or fragment, -// they will be omitted. The query parameters will be encoded. -func SafeURL(url *url.URL) string { - if url == nil { - return "" - } - strURL := fmt.Sprintf("%s://%s", url.Scheme, url.Host) - if url.Path != "" { - strURL += url.Path - } - if url.Fragment != "" { - strURL += fmt.Sprintf("#%s", url.Fragment) - } - if encoded := url.Query().Encode(); encoded != "" { - strURL += fmt.Sprintf("?%s", encoded) - } - return strURL -} diff --git a/internal/error.go b/internal/error.go new file mode 100644 index 0000000..852d61c --- /dev/null +++ b/internal/error.go @@ -0,0 +1,41 @@ +package internal + +import "fmt" + +// Error represents an error with a message and a trace. It is a custom struct +// that can be used to wrap errors with additional information. +type Error struct { + msg string + trace error +} + +// NewErr creates a new Error instance with the given message. It is a helper +// function that simplifies the creation of Error instances in other packages. +func NewErr(msg string) *Error { + return &Error{msg: msg} +} + +// Error returns the error message. If the error has a trace, it is appended to +// the message. The error implements the error interface. +func (e *Error) Error() string { + err := fmt.Errorf("%s", e.msg) + if e.trace != nil { + err = fmt.Errorf("%s: %w", err, e.trace) + } + return err.Error() +} + +// With adds an error as a trace to the error. It is a helper function that +// simplifies the addition of traces to Error instances in other packages. +func (e *Error) With(err error) *Error { + e.trace = err + return e +} + +// Withf adds a formatted error message as a trace to the error. It is a helper +// function that simplifies the addition of formatted traces to Error instances +// in other packages. +func (e *Error) Withf(tmpl string, args ...any) *Error { + e.trace = fmt.Errorf(tmpl, args...) + return e +} diff --git a/internal/error_test.go b/internal/error_test.go new file mode 100644 index 0000000..6cb6e8d --- /dev/null +++ b/internal/error_test.go @@ -0,0 +1,49 @@ +package internal + +import ( + "errors" + "fmt" + "testing" +) + +func TestNewErr(t *testing.T) { + err := NewErr("test message") + if err.msg != "test message" { + t.Errorf("expected message 'test message', got '%s'", err.msg) + } + if err.trace != nil { + t.Errorf("expected nil trace, got '%v'", err.trace) + } +} + +func TestError(t *testing.T) { + err := NewErr("test message") + if err.Error() != "test message" { + t.Errorf("expected 'test message', got '%s'", err.Error()) + } + + wrappedErr := errors.New("wrapped error") + _ = err.With(wrappedErr) + expected := "test message: wrapped error" + if err.Error() != expected { + t.Errorf("expected '%s', got '%s'", expected, err.Error()) + } +} + +func TestWith(t *testing.T) { + err := NewErr("test message") + wrappedErr := errors.New("wrapped error") + _ = err.With(wrappedErr) + if err.trace != wrappedErr { + t.Errorf("expected trace '%v', got '%v'", wrappedErr, err.trace) + } +} + +func TestWithf(t *testing.T) { + err := NewErr("test message") + _ = err.Withf("formatted %s", "error") + expectedTrace := fmt.Errorf("formatted %s", "error").Error() + if err.trace.Error() != expectedTrace { + t.Errorf("expected trace '%s', got '%s'", expectedTrace, err.trace.Error()) + } +} diff --git a/internal/fakesmtpserver/server.go b/internal/fakesmtpserver/server.go new file mode 100644 index 0000000..dee96c4 --- /dev/null +++ b/internal/fakesmtpserver/server.go @@ -0,0 +1,134 @@ +package fakesmtpserver + +// fakesmtpserver package provides a simple SMTP server for testing purposes. +// It allows you to simulate an SMTP server that can receive emails and store +// them in a channel. This is useful for testing email sending functionality +// in applications without needing to set up a real SMTP server. The server +// can be started and stopped, and it handles basic SMTP commands like HELO, +// MAIL FROM, RCPT TO, and DATA. It also provides a way to retrieve the +// received emails from the inbox channel. + +import ( + "bufio" + "context" + "fmt" + "net" + "strings" + "sync" +) + +// FakeSMTPServer represents a simple SMTP testing server. +type FakeSMTPServer struct { + addr string + inbox chan string + listener net.Listener + mu sync.Mutex // Mutex to protect listener +} + +// NewServer creates a new FakeSMTPServer instance that listens on the given +// address and port and stores the received emails in the inbox channel +// provided. +func NewServer(addr string, port int, inbox chan string) *FakeSMTPServer { + return &FakeSMTPServer{ + addr: fmt.Sprintf("%s:%d", addr, port), + inbox: inbox, + } +} + +// Start method launches the test SMTP server. +func (s *FakeSMTPServer) Start(ctx context.Context) error { + var err error + s.mu.Lock() + s.listener, err = net.Listen("tcp", s.addr) + s.mu.Unlock() + if err != nil { + return err + } + go func() { + for { + select { + case <-ctx.Done(): + // use Stop to safely close the listener + s.Stop() + return + default: + // copy listener under lock + s.mu.Lock() + listener := s.listener + s.mu.Unlock() + if listener == nil { + return + } + conn, err := listener.Accept() + if err != nil { + continue + } + go s.handleConn(conn) + } + } + }() + return nil +} + +// Stop method shuts down the test SMTP server. +func (s *FakeSMTPServer) Stop() { + // copy listener under lock + s.mu.Lock() + listener := s.listener + // set listener to nil under lock and unlock + s.listener = nil + s.mu.Unlock() + // close the listener if it is not nil + if listener != nil { + listener.Close() + } +} + +func (s *FakeSMTPServer) handleConn(conn net.Conn) { + defer conn.Close() + reader := bufio.NewReader(conn) + // send greeting + fmt.Fprintf(conn, "220 Fake SMTP Service Ready\r\n") + var dataBuilder strings.Builder + inData := false + // read incoming data + for { + line, err := reader.ReadString('\n') + if err != nil { + return + } + line = strings.TrimRight(line, "\r\n") + // check if we are in the data section + if inData { + if line == "." { + inData = false + // send back a confirmation and store the data + fmt.Fprintf(conn, "250 OK\r\n") + s.inbox <- dataBuilder.String() + dataBuilder.Reset() + continue + } + dataBuilder.WriteString(line + "\n") + continue + } + // simple command handling + switch { + case strings.HasPrefix(line, "HELO"), strings.HasPrefix(line, "EHLO"): + fmt.Fprintf(conn, "250 Hello\r\n") + case strings.HasPrefix(line, "MAIL FROM:"): + fmt.Fprintf(conn, "250 OK\r\n") + case strings.HasPrefix(line, "RCPT TO:"): + fmt.Fprintf(conn, "250 OK\r\n") + case strings.HasPrefix(line, "DATA"): + // prepare to receive data + fmt.Fprintf(conn, "354 End data with .\r\n") + inData = true + case strings.HasPrefix(line, "QUIT"): + // close the connection + fmt.Fprintf(conn, "221 Bye\r\n") + return + default: + fmt.Fprintf(conn, "250 OK\r\n") + } + } +} diff --git a/internal/fakesmtpserver/server_test.go b/internal/fakesmtpserver/server_test.go new file mode 100644 index 0000000..de6e351 --- /dev/null +++ b/internal/fakesmtpserver/server_test.go @@ -0,0 +1,239 @@ +package fakesmtpserver + +import ( + "bufio" + "context" + "net" + "strconv" + "strings" + "testing" + "time" +) + +func getFreePort() (string, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return "", err + } + defer listener.Close() + return listener.Addr().String(), nil +} + +func splitHostPort(address string) (string, int, error) { + host, portStr, err := net.SplitHostPort(address) + if err != nil { + return "", 0, err + } + port, err := strconv.Atoi(portStr) + if err != nil { + return "", 0, err + } + return host, port, nil +} + +func TestFakeSMTPServer(t *testing.T) { + inbox := make(chan string, 1) + address, err := getFreePort() + if err != nil { + t.Fatalf("Failed to get free port: %v", err) + } + host, port, err := splitHostPort(address) + if err != nil { + t.Fatalf("Failed to split host and port: %v", err) + } + server := NewServer(host, port, inbox) + ctx, cancel := context.WithCancel(t.Context()) // Fixed incorrect t.Context() + defer cancel() + + errChan := make(chan error, 1) // Channel to capture errors from the goroutine + + // Start the server + go func() { + if err := server.Start(ctx); err != nil { + errChan <- err // Send error to the channel + } + close(errChan) // Close the channel when done + }() + time.Sleep(100 * time.Millisecond) // Give the server time to start + + // Check for errors from the goroutine + select { + case err := <-errChan: + if err != nil { + t.Fatalf("Failed to start server: %v", err) + } + default: + // No error, continue with the test + } + + // Connect to the server and send an email + conn, err := net.Dial("tcp", address) + if err != nil { + t.Fatalf("Failed to connect to server: %v", err) + } + defer conn.Close() + + reader := bufio.NewReader(conn) + writer := bufio.NewWriter(conn) + + // Read greeting + if greeting, _ := reader.ReadString('\n'); !strings.HasPrefix(greeting, "220") { + t.Fatalf("Expected greeting, got: %s", greeting) + } + + // Send HELO + if _, err := writer.WriteString("HELO localhost\r\n"); err != nil { + t.Fatalf("Failed to write HELO command: %v", err) + } + writer.Flush() + if response, _ := reader.ReadString('\n'); !strings.HasPrefix(response, "250") { + t.Fatalf("Expected HELO response, got: %s", response) + } + + // Send MAIL FROM + if _, err := writer.WriteString("MAIL FROM:\r\n"); err != nil { + t.Fatalf("Failed to write MAIL FROM command: %v", err) + } + writer.Flush() + if response, _ := reader.ReadString('\n'); !strings.HasPrefix(response, "250") { + t.Fatalf("Expected MAIL FROM response, got: %s", response) + } + + // Send RCPT TO + if _, err := writer.WriteString("RCPT TO:\r\n"); err != nil { + t.Fatalf("Failed to write RCPT TO command: %v", err) + } + writer.Flush() + if response, _ := reader.ReadString('\n'); !strings.HasPrefix(response, "250") { + t.Fatalf("Expected RCPT TO response, got: %s", response) + } + + // Send DATA + if _, err := writer.WriteString("DATA\r\n"); err != nil { + t.Fatalf("Failed to write DATA command: %v", err) + } + writer.Flush() + if response, _ := reader.ReadString('\n'); !strings.HasPrefix(response, "354") { + t.Fatalf("Expected DATA response, got: %s", response) + } + + // Send email content + if _, err := writer.WriteString("Subject: Test Email\r\n\r\nThis is a test email.\r\n.\r\n"); err != nil { + t.Fatalf("Failed to write email content: %v", err) + } + writer.Flush() + if response, _ := reader.ReadString('\n'); !strings.HasPrefix(response, "250") { + t.Fatalf("Expected email content response, got: %s", response) + } + + // Send QUIT + if _, err := writer.WriteString("QUIT\r\n"); err != nil { + t.Fatalf("Failed to write QUIT command: %v", err) + } + writer.Flush() + if response, _ := reader.ReadString('\n'); !strings.HasPrefix(response, "221") { + t.Fatalf("Expected QUIT response, got: %s", response) + } + + // Verify email content in inbox + select { + case email := <-inbox: + if !strings.Contains(email, "Subject: Test Email") || !strings.Contains(email, "This is a test email.") { + t.Fatalf("Unexpected email content: %s", email) + } + case <-time.After(1 * time.Second): + t.Fatal("Timeout waiting for email in inbox") + } + + // Stop the server + server.Stop() +} + +func TestFakeSMTPServer_UnsupportedCommand(t *testing.T) { + inbox := make(chan string, 1) + address, err := getFreePort() + if err != nil { + t.Fatalf("Failed to get free port: %v", err) + } + host, port, err := splitHostPort(address) + if err != nil { + t.Fatalf("Failed to split host and port: %v", err) + } + server := NewServer(host, port, inbox) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + errChan := make(chan error, 1) + + // Start the server + go func() { + if err := server.Start(ctx); err != nil { + errChan <- err + } + close(errChan) + }() + time.Sleep(100 * time.Millisecond) + + // Check for errors from the goroutine + select { + case err := <-errChan: + if err != nil { + t.Fatalf("Failed to start server: %v", err) + } + default: + } + + // Connect to the server and send an unsupported command + conn, err := net.Dial("tcp", address) + if err != nil { + t.Fatalf("Failed to connect to server: %v", err) + } + defer conn.Close() + + reader := bufio.NewReader(conn) + writer := bufio.NewWriter(conn) + + // Read greeting + if greeting, _ := reader.ReadString('\n'); !strings.HasPrefix(greeting, "220") { + t.Fatalf("Expected greeting, got: %s", greeting) + } + + // Send unsupported command + if _, err := writer.WriteString("FOO BAR\r\n"); err != nil { + t.Fatalf("Failed to write unsupported command: %v", err) + } + writer.Flush() + if response, _ := reader.ReadString('\n'); !strings.HasPrefix(response, "250") { + t.Fatalf("Expected default response, got: %s", response) + } + + // Stop the server + server.Stop() +} + +func TestFakeSMTPServer_BadAddress(t *testing.T) { + inbox := make(chan string, 1) + server := NewServer("invalid-address", 2527, inbox) // Added a dummy port + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + errChan := make(chan error, 1) + + // Start the server + go func() { + if err := server.Start(ctx); err != nil { + errChan <- err + } + close(errChan) + }() + time.Sleep(100 * time.Millisecond) + + // Check for errors from the goroutine + select { + case err := <-errChan: + if err == nil { + t.Fatalf("Expected error starting server with bad address, got nil") + } + default: + } +} diff --git a/internal/osflag/osflag.go b/internal/osflag/osflag.go new file mode 100644 index 0000000..074e298 --- /dev/null +++ b/internal/osflag/osflag.go @@ -0,0 +1,250 @@ +package osflag + +// osflag package provides a way to manage command line flags and environment +// variables in Go applications. It allows for the creation of command line +// flags that can also be overwritten by environment variables. By default it +// loads a `.env` file, but this can be overridden by passing an `WithEnvFile` +// option to the `Parse` method. It also checks for required flags and ensures +// that they are set before parsing the command line arguments. + +import ( + "bufio" + "flag" + "fmt" + "os" + "strings" + "time" +) + +// Options is a struct that holds options for the Parse method. +type Options struct { + envFile string +} + +// WithEnvFile function creates an Options instance with the specified +// envFile path. If the path is empty, it returns nil. This function is used +// to specify a custom env file path when calling the Parse method. +func WithEnvFile(path string) *Options { + if path == "" { + return nil + } + return &Options{envFile: path} +} + +// osflag is a struct that holds the name, environment variable, and +// required mark of a flag. It is used to manage command line flags +// and their corresponding environment variables. +type osflag struct { + name string + env string + required bool +} + +// OsFlagSet is a struct that embeds flag.FlagSet and adds support for env +// variables. It allows for the creation of command line flags that can also +// be overwritten by environment variables. By default it loads `.env` file, +// but this can be overridden by passing an WithEnvFile option to the Parse +// method. It also checks for required flags and ensures that they are set +// before parsing the command line arguments. +type OsFlagSet struct { + *flag.FlagSet + flags map[string]osflag + parsed bool +} + +// CommandLine is the default OsFlagSet instance. +var CommandLine *OsFlagSet + +// init initializes the CommandLine variable with a new OsFlagSet instance. +func init() { + CommandLine = new(OsFlagSet) + if len(os.Args) == 0 { + CommandLine.FlagSet = flag.NewFlagSet("", flag.ExitOnError) + } else { + CommandLine.FlagSet = flag.NewFlagSet(os.Args[0], flag.ExitOnError) + } + CommandLine.flags = make(map[string]osflag) +} + +// BoolVar method registers a boolean flag with the given name, env variable, +// default value, usage string, and required mark. +func (of *OsFlagSet) BoolVar(p *bool, env, name string, value bool, usage string, required bool) { + of.flags[name] = osflag{name, env, required} + of.FlagSet.BoolVar(p, name, value, usage) +} + +// DurationVar method registers a duration flag with the given name, env +// variable, default value, usage string, and required mark. +func (of *OsFlagSet) DurationVar(p *time.Duration, env, name string, value time.Duration, usage string, required bool) { + of.flags[name] = osflag{name, env, required} + of.FlagSet.DurationVar(p, name, value, usage) +} + +// Float64Var method registers a float64 flag with the given name, env variable, +// default value, usage string, and required mark. +func (of *OsFlagSet) Float64Var(p *float64, env, name string, value float64, usage string, required bool) { + of.flags[name] = osflag{name, env, required} + of.FlagSet.Float64Var(p, name, value, usage) +} + +// IntVar method registers an int flag with the given name, env variable, +// default value, usage string, and required mark. +func (of *OsFlagSet) IntVar(p *int, env, name string, value int, usage string, required bool) { + of.flags[name] = osflag{name, env, required} + of.FlagSet.IntVar(p, name, value, usage) +} + +// StringVar method registers a string flag with the given name, env variable, +// default value, usage string, and required mark. +func (of *OsFlagSet) StringVar(p *string, env, name string, value string, usage string, required bool) { + of.flags[name] = osflag{name, env, required} + of.FlagSet.StringVar(p, name, value, usage) +} + +// UintVar method registers a uint flag with the given name, env variable, +// default value, usage string, and required mark. +func (of *OsFlagSet) UintVar(p *uint, env, name string, value uint, usage string, required bool) { + of.flags[name] = osflag{name, env, required} + of.FlagSet.UintVar(p, name, value, usage) +} + +// Parse method parses the command line arguments and loads the environment +// variables from the specified env file. It checks if all required flags are +// set and it overwrites the command line flags with the values from the env +// variables if they are set. It returns an error if any required flags are +// not set or if there is an error loading the env file. +func (of *OsFlagSet) Parse(opts *Options) error { + if err := of.FlagSet.Parse(os.Args[1:]); err != nil { + return err + } + // load the env file + envFile := ".env" + if opts != nil && opts.envFile != "" { + envFile = opts.envFile + } + if err := loadEnv(envFile); err != nil { + return fmt.Errorf("failed to load env file: %w", err) + } + // check if all required flags are set + for name, osf := range of.flags { + if envValue := os.Getenv(osf.env); envValue != "" { + if err := of.FlagSet.Set(name, envValue); err != nil { + return fmt.Errorf("failed to set flag %s from env: %w", name, err) + } + } + // check if the flag is required and not set + if osf.required { + f := of.FlagSet.Lookup(name) + if f == nil || f.Value.String() == "" { + return fmt.Errorf("required flag %s is not set", name) + } + } + } + of.parsed = of.FlagSet.Parsed() + return nil +} + +// Parsed method returns true if the command line arguments have been parsed. +func (of *OsFlagSet) Parsed() bool { + return of.parsed +} + +// PrintDefaults method prints the default values of all flags. +func (of *OsFlagSet) PrintDefaults() { + of.FlagSet.PrintDefaults() +} + +// BoolVar method registers a boolean flag with the given name, env variable, +// default value, usage string, and required mark. +func BoolVar(p *bool, env, name string, value bool, usage string, required bool) { + CommandLine.BoolVar(p, env, name, value, usage, required) +} + +// DurationVar method registers a duration flag with the given name, env +// variable, default value, usage string, and required mark. +func DurationVar(p *time.Duration, env, name string, value time.Duration, usage string, required bool) { + CommandLine.DurationVar(p, env, name, value, usage, required) +} + +// Float64Var method registers a float64 flag with the given name, env variable, +// default value, usage string, and required mark. +func Float64Var(p *float64, env, name string, value float64, usage string, required bool) { + CommandLine.Float64Var(p, env, name, value, usage, required) +} + +// IntVar method registers an int flag with the given name, env variable, +// default value, usage string, and required mark. +func IntVar(p *int, env, name string, value int, usage string, required bool) { + CommandLine.IntVar(p, env, name, value, usage, required) +} + +// StringVar method registers a string flag with the given name, env variable, +// default value, usage string, and required mark. +func StringVar(p *string, env, name string, value string, usage string, required bool) { + CommandLine.StringVar(p, env, name, value, usage, required) +} + +// UintVar method registers a uint flag with the given name, env variable, +// default value, usage string, and required mark. +func UintVar(p *uint, env, name string, value uint, usage string, required bool) { + CommandLine.UintVar(p, env, name, value, usage, required) +} + +// Parse method parses the command line arguments and loads the environment +// variables from the specified env file. It checks if all required flags are +// set and it overwrites the command line flags with the values from the env +// variables if they are set. It returns an error if any required flags are +// not set or if there is an error loading the env file. +func Parse(opts *Options) error { + return CommandLine.Parse(opts) +} + +// Parsed method returns true if the command line arguments have been parsed. +func Parsed() bool { + return CommandLine.parsed +} + +// PrintDefaults method prints the default values of all flags. +func PrintDefaults() { + CommandLine.PrintDefaults() +} + +// loadEnv function loads environment variables from a file. If the file does +// not exist, it returns nil and does not raise an error. It reads the file +// line by line, ignoring empty lines and comments. It sets the environment +// variables in the current process using os.Setenv. It removes any "export " +// prefix and surrounding quotes from the variable assignments. It returns +// an error if there is an issue opening or reading the file (different from +// the file not existing). +func loadEnv(path string) error { + envFile, err := os.Open(path) + if err != nil { + // if the file does not exist, return nil + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("failed to open env file: %w", err) + } + defer envFile.Close() + // create a line scanner + scanner := bufio.NewScanner(envFile) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue // skip empty lines and comments + } + // remove "export " prefix if present + line = strings.TrimPrefix(line, "export ") + // split on the first '=' character + parts := strings.SplitN(line, "=", 2) + if len(parts) != 2 { + continue // or return an error if preferred + } + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + value = strings.Trim(value, "\"'") // remove surrounding quotes + // set var in the current env + os.Setenv(key, value) + } + return nil +} diff --git a/internal/osflag/osflag_test.go b/internal/osflag/osflag_test.go new file mode 100644 index 0000000..15a2a46 --- /dev/null +++ b/internal/osflag/osflag_test.go @@ -0,0 +1,160 @@ +package osflag + +import ( + "flag" + "os" + "testing" + "time" +) + +func resetCommandLine() { + CommandLine = new(OsFlagSet) + CommandLine.FlagSet = flag.NewFlagSet("", flag.ExitOnError) + CommandLine.flags = make(map[string]osflag) + + // Filter out test framework flags + os.Args = os.Args[:1] +} + +func TestBoolVar(t *testing.T) { + resetCommandLine() + var flagValue bool + os.Setenv("TEST_BOOL", "true") + defer os.Unsetenv("TEST_BOOL") + + CommandLine.BoolVar(&flagValue, "TEST_BOOL", "boolFlag", false, "A boolean flag", false) + if err := CommandLine.Parse(nil); err != nil { + t.Fatalf("Failed to parse command line: %v", err) + } + + if !flagValue { + t.Errorf("Expected true, got %v", flagValue) + } +} + +func TestDurationVar(t *testing.T) { + resetCommandLine() + var flagValue time.Duration + os.Setenv("TEST_DURATION", "5s") + defer os.Unsetenv("TEST_DURATION") + + CommandLine.DurationVar(&flagValue, "TEST_DURATION", "durationFlag", 0, "A duration flag", false) + if err := CommandLine.Parse(nil); err != nil { + t.Fatalf("Failed to parse command line: %v", err) + } + + if flagValue != 5*time.Second { + t.Errorf("Expected 5s, got %v", flagValue) + } +} + +func TestFloat64Var(t *testing.T) { + resetCommandLine() + var flagValue float64 + os.Setenv("TEST_FLOAT", "3.14") + defer os.Unsetenv("TEST_FLOAT") + + CommandLine.Float64Var(&flagValue, "TEST_FLOAT", "floatFlag", 0.0, "A float flag", false) + if err := CommandLine.Parse(nil); err != nil { + t.Fatalf("Failed to parse command line: %v", err) + } + + if flagValue != 3.14 { + t.Errorf("Expected 3.14, got %v", flagValue) + } +} + +func TestIntVar(t *testing.T) { + resetCommandLine() + var flagValue int + os.Setenv("TEST_INT", "42") + defer os.Unsetenv("TEST_INT") + + CommandLine.IntVar(&flagValue, "TEST_INT", "intFlag", 0, "An int flag", false) + if err := CommandLine.Parse(nil); err != nil { + t.Fatalf("Failed to parse command line: %v", err) + } + + if flagValue != 42 { + t.Errorf("Expected 42, got %v", flagValue) + } +} + +func TestStringVar(t *testing.T) { + resetCommandLine() + var flagValue string + os.Setenv("TEST_STRING", "hello") + defer os.Unsetenv("TEST_STRING") + + CommandLine.StringVar(&flagValue, "TEST_STRING", "stringFlag", "default", "A string flag", false) + if err := CommandLine.Parse(nil); err != nil { + t.Fatalf("Failed to parse command line: %v", err) + } + + if flagValue != "hello" { + t.Errorf("Expected 'hello', got %v", flagValue) + } +} + +func TestUintVar(t *testing.T) { + resetCommandLine() + var flagValue uint + os.Setenv("TEST_UINT", "100") + defer os.Unsetenv("TEST_UINT") + + CommandLine.UintVar(&flagValue, "TEST_UINT", "uintFlag", 0, "A uint flag", false) + if err := CommandLine.Parse(nil); err != nil { + t.Fatalf("Failed to parse command line: %v", err) + } + + if flagValue != 100 { + t.Errorf("Expected 100, got %v", flagValue) + } +} + +func TestRequiredFlag(t *testing.T) { + resetCommandLine() + var flagValue string + CommandLine.StringVar(&flagValue, "", "requiredFlag", "", "A required flag", true) + + if err := CommandLine.Parse(nil); err == nil { + t.Errorf("Expected error for missing required flag, got nil") + } +} + +func TestDefaultValues(t *testing.T) { + resetCommandLine() + var flagValue string + CommandLine.StringVar(&flagValue, "", "defaultFlag", "defaultValue", "A flag with a default value", false) + if err := CommandLine.Parse(nil); err != nil { + t.Fatalf("Failed to parse command line: %v", err) + } + + if flagValue != "defaultValue" { + t.Errorf("Expected 'defaultValue', got %v", flagValue) + } +} + +func TestLoadEnv(t *testing.T) { + resetCommandLine() + // try to load a non-existing env file (should not error) + if err := loadEnv("non_existing.env"); err != nil { + t.Fatalf("Expected no error for non-existing env file, got: %v", err) + } + // create .env file + envFileContent := []byte("TEST_ENV=envValue") + envFilePath := ".env" + if err := os.WriteFile(envFilePath, envFileContent, 0o644); err != nil { + t.Fatalf("Failed to create env file: %v", err) + } + defer os.Remove(envFilePath) + // parse flags and check the value + var flagValue string + CommandLine.StringVar(&flagValue, "TEST_ENV", "envFlag", "defaultValue", "A flag with an env variable", false) + if err := CommandLine.Parse(nil); err != nil { + t.Fatalf("Failed to parse command line: %v", err) + } + if flagValue != "envValue" { + t.Errorf("Expected 'envValue', got %v", flagValue) + } +} diff --git a/notification/email/emailqueue.go b/notification/email/emailqueue.go new file mode 100644 index 0000000..f381db3 --- /dev/null +++ b/notification/email/emailqueue.go @@ -0,0 +1,221 @@ +package email + +import ( + "bytes" + "context" + "fmt" + "mime/multipart" + "net/mail" + "net/smtp" + "net/textproto" + "sync" + + "github.com/simpleauthlink/authapi/notification" +) + +// defaultSendRetries is the default number of retries to send the email. +const defaultSendRetries = 3 + +// EmailConfig struct represents the email configuration that is needed to send +// an email using and SMTP server. It includes the email address (used as the +// sender address but also as the username for the SMTP server), the email +// server hostname, its port and the password. +type EmailConfig struct { + FromName string + FromAddress string + SMTPUsername string + SMTPPassword string + SMTPServer string + SMTPPort int + Retries int + ErrorCh chan error +} + +// Valid method checks if the email configuration is valid. It returns true if +// the sender name, the SMTP server and its port are not empty, and the sender +// email address is valid. It also sets the number of retries to the default +// value if it is not set. +func (cfg *EmailConfig) Valid() bool { + if cfg.FromName == "" || cfg.SMTPServer == "" || cfg.SMTPPort == 0 { + return false + } + if cfg.Retries == 0 { + cfg.Retries = defaultSendRetries + } + _, err := mail.ParseAddress(cfg.FromAddress) + return err == nil +} + +// EmailQueue struct represents the email queue. It includes the context and the +// cancel function to stop the queue, the configuration of the server to send +// the email, the list of emails to send, and the waiter to wait for the +// background process to finish. +type EmailQueue struct { + ctx context.Context + cancel context.CancelFunc + cfg *EmailConfig + auth smtp.Auth + items []*notification.Notification + itemsMtx sync.Mutex + waiter sync.WaitGroup + errCh chan error +} + +// NewEmailQueue creates a new EmailQueue with the provided configuration. +func NewEmailQueue(ctx context.Context, cfg *EmailConfig) (*EmailQueue, error) { + // check if the configuration is valid + if !cfg.Valid() { + return nil, ErrInvalidConfig + } + // init the email queue + internalCtx, cancel := context.WithCancel(ctx) + eq := &EmailQueue{ + ctx: internalCtx, + cancel: cancel, + cfg: cfg, + items: []*notification.Notification{}, + errCh: cfg.ErrorCh, + } + // init SMTP auth + if cfg.SMTPUsername != "" && cfg.SMTPPassword != "" { + eq.auth = smtp.PlainAuth("", cfg.SMTPUsername, cfg.SMTPPassword, cfg.SMTPServer) + } + // return the email queue + return eq, nil +} + +// Start method starts the email queue. It listens for new emails in the queue +// and sends them using the provided configuration. +func (eq *EmailQueue) Start() { + eq.waiter.Add(1) + go func() { + defer eq.waiter.Done() + for { + select { + case <-eq.ctx.Done(): + return + default: + e, ok := eq.Pop() + if !ok { + continue + } + if err := eq.Send(e); err != nil { + if eq.errCh != nil { + eq.errCh <- err + } + } + } + } + }() +} + +// Stop method stops the email queue. +func (eq *EmailQueue) Stop() { + eq.cancel() + eq.waiter.Wait() +} + +// Pop method removes the first email in the queue and returns it. +func (eq *EmailQueue) Pop() (notification.Notification, bool) { + eq.itemsMtx.Lock() + defer eq.itemsMtx.Unlock() + if len(eq.items) == 0 { + return notification.Notification{}, false + } + e := eq.items[0] + eq.items = eq.items[1:] + return *e, true +} + +// Push method adds a new email to the queue. +func (eq *EmailQueue) Push(n notification.Notification) error { + // check if the email is valid + if !n.Valid() { + return ErrInvalidEmail + } + eq.itemsMtx.Lock() + eq.items = append(eq.items, &n) + eq.itemsMtx.Unlock() + return nil +} + +// Send method sends the email using the queue configuration. It uses the +// email address as the sender address and the username for the SMTP server. +// It composes the email message, creates the auth object with the email +// credentials, the server string with the host and the port, and the receipts. +// Finally, it sends the email. If something fails during the process, it +// returns an error. It can be used even the queue is not started. +func (eq *EmailQueue) Send(n notification.Notification) error { + // check if the email is valid + if !n.Valid() { + return ErrInvalidEmail + } + // compose the email body + body, err := eq.composeBody(n) + if err != nil { + return ErrComposeEmail.With(err) + } + // create the server string with the host and the port and the receipts + server := fmt.Sprintf("%s:%d", eq.cfg.SMTPServer, eq.cfg.SMTPPort) + receipts := []string{n.Params.To} + // send the email + for i := 0; i < eq.cfg.Retries; i++ { + if err = smtp.SendMail(server, eq.auth, eq.cfg.FromAddress, receipts, body); err == nil { + break + } + } + if err != nil { + return ErrSendEmail.With(err) + } + return nil +} + +// composeBody creates the email body with the message data. It creates a +// multipart email with a plain text and an HTML part. It returns the email +// content as a byte slice or an error if the body could not be composed. +func (eq *EmailQueue) composeBody(n notification.Notification) ([]byte, error) { + // parse 'to' email address + to, err := mail.ParseAddress(n.Params.To) + if err != nil { + return nil, ErrParseAddress.With(err) + } + // create email headers + var headers bytes.Buffer + boundary := "----=_Part_0_123456789.123456789" + headers.WriteString(fmt.Sprintf("From: %s\r\n", eq.cfg.FromAddress)) + headers.WriteString(fmt.Sprintf("To: %s\r\n", to.String())) + headers.WriteString(fmt.Sprintf("Subject: %s\r\n", n.Params.Subject)) + headers.WriteString("MIME-Version: 1.0\r\n") + headers.WriteString(fmt.Sprintf("Content-Type: multipart/alternative; boundary=\"%s\"\r\n", boundary)) + headers.WriteString("\r\n") // blank line between headers and body + // create multipart writer + var body bytes.Buffer + writer := multipart.NewWriter(&body) + if err := writer.SetBoundary(boundary); err != nil { + return nil, ErrSetBoundary.With(err) + } + // plain text part + textPart, _ := writer.CreatePart(textproto.MIMEHeader{ + "Content-Type": {"text/plain; charset=\"UTF-8\""}, + "Content-Transfer-Encoding": {"7bit"}, + }) + if _, err := textPart.Write(n.PlainBody); err != nil { + return nil, ErrWriteBody.With(err) + } + // HTML part + htmlPart, _ := writer.CreatePart(textproto.MIMEHeader{ + "Content-Type": {"text/html; charset=\"UTF-8\""}, + "Content-Transfer-Encoding": {"7bit"}, + }) + if _, err := htmlPart.Write(n.Body); err != nil { + return nil, ErrWriteHTMLBody.With(err) + } + if err := writer.Close(); err != nil { + return nil, ErrCloseEmailWriter.With(err) + } + // combine headers and body and return the content + var email bytes.Buffer + email.Write(headers.Bytes()) + email.Write(body.Bytes()) + return email.Bytes(), nil +} diff --git a/notification/email/emailqueue_test.go b/notification/email/emailqueue_test.go new file mode 100644 index 0000000..8c3ad2a --- /dev/null +++ b/notification/email/emailqueue_test.go @@ -0,0 +1,302 @@ +package email + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/simpleauthlink/authapi/internal/fakesmtpserver" + "github.com/simpleauthlink/authapi/notification" +) + +const ( + testServerAddr = "127.0.0.1" + testServerPort = 2525 + testSenderName = "Test Sender" + testSender = "sender@testmail.com" + testReceiver = "receiver@testmail.com" + testSubject = "Test email" + testBody = "This is a test email" + testHTMLBody = "

This is a test email

" +) + +var inboxChan = make(chan string, 1) + +func TestMain(m *testing.M) { + defer close(inboxChan) + // create context with cancel + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // start test SMTP server to receive the email + testSrv := fakesmtpserver.NewServer(testServerAddr, testServerPort, inboxChan) + if err := testSrv.Start(ctx); err != nil { + panic(err) + } + defer testSrv.Stop() + m.Run() +} + +func TestValidEmail(t *testing.T) { + if !(¬ification.Notification{ + Params: notification.NotificationParams{ + To: testReceiver, + Subject: testSubject, + }, + Body: nil, + PlainBody: []byte(testBody), + }).Valid() { + t.Error("expected email to be valid") + } + if !(¬ification.Notification{ + Params: notification.NotificationParams{ + To: testReceiver, + Subject: testSubject, + }, + Body: []byte(testBody), + PlainBody: nil, + }).Valid() { + t.Error("expected email to be valid") + } + if (¬ification.Notification{ + Params: notification.NotificationParams{ + To: testReceiver, + Subject: "", + }, + Body: nil, + PlainBody: []byte(testBody), + }).Valid() { + t.Error("expected email to be invalid") + } + if (¬ification.Notification{ + Params: notification.NotificationParams{ + To: "", + Subject: testSubject, + }, + Body: nil, + PlainBody: []byte(testBody), + }).Valid() { + t.Error("expected email to be invalid") + } + if (¬ification.Notification{ + Params: notification.NotificationParams{ + To: "invalidEmail", + Subject: testSubject, + }, + Body: nil, + PlainBody: []byte(testBody), + }).Valid() { + t.Error("expected email to be invalid") + } + if (¬ification.Notification{}).Valid() { + t.Error("expected email to be invalid") + } +} + +func TestValidConfig(t *testing.T) { + if !(&EmailConfig{ + SMTPServer: testServerAddr, + SMTPPort: testServerPort, + FromName: testSenderName, + FromAddress: testSender, + }).Valid() { + t.Error("expected config to be valid") + } + if (&EmailConfig{ + SMTPServer: "", + SMTPPort: testServerPort, + FromName: testSenderName, + FromAddress: testSender, + }).Valid() { + t.Error("expected config to be invalid") + } + if (&EmailConfig{ + SMTPServer: testServerAddr, + SMTPPort: 0, + FromName: testSenderName, + FromAddress: testSender, + }).Valid() { + t.Error("expected config to be invalid") + } + if (&EmailConfig{ + SMTPServer: testServerAddr, + SMTPPort: testServerPort, + FromName: "", + FromAddress: testSender, + }).Valid() { + t.Error("expected config to be invalid") + } + if (&EmailConfig{ + SMTPServer: testServerAddr, + SMTPPort: testServerPort, + FromName: testSenderName, + FromAddress: "", + }).Valid() { + t.Error("expected config to be invalid") + } +} + +func TestNewEmailQueue(t *testing.T) { + // create email queue with valid config + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + eq, err := NewEmailQueue(ctx, &EmailConfig{ + SMTPServer: testServerAddr, + SMTPPort: testServerPort, + FromName: testSenderName, + FromAddress: testSender, + }) + if err != nil { + t.Fatal(err) + } + if eq == nil { + t.Error("expected email queue to be created") + } + // create email queue with auth + eq, err = NewEmailQueue(ctx, &EmailConfig{ + SMTPServer: testServerAddr, + SMTPPort: testServerPort, + FromName: testSenderName, + FromAddress: testSender, + SMTPUsername: "username", + SMTPPassword: "password", + }) + if err != nil { + t.Fatal(err) + } + if eq == nil { + t.Error("expected email queue to be created") + } + // create email queue with invalid config + eq, err = NewEmailQueue(ctx, &EmailConfig{ + SMTPServer: "", + SMTPPort: testServerPort, + FromName: testSenderName, + FromAddress: testSender, + }) + if err == nil { + t.Error("expected error creating email queue") + } + if eq != nil { + t.Error("expected email queue to be nil") + } +} + +func TestSendEmail(t *testing.T) { + // create email queue but don't start it + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + eq, err := NewEmailQueue(ctx, &EmailConfig{ + SMTPServer: testServerAddr, + SMTPPort: testServerPort, + FromName: testSenderName, + FromAddress: testSender, + }) + if err != nil { + t.Fatal(err) + } + // send email + if err := eq.Send(notification.Notification{ + Params: notification.NotificationParams{ + To: testReceiver, + Subject: testSubject, + }, + Body: []byte(testHTMLBody), + PlainBody: []byte(testBody), + }); err != nil { + t.Fatal(err) + } + // check if the email was received + select { + case receivedMsg := <-inboxChan: + if !strings.Contains(receivedMsg, testSubject) { + t.Errorf("expected email content to contain %q, got %q", testSubject, receivedMsg) + } + if !strings.Contains(receivedMsg, testBody) { + t.Errorf("expected email content to contain %q, got %q", testBody, receivedMsg) + } + if !strings.Contains(receivedMsg, testHTMLBody) { + t.Errorf("expected email content to contain %q, got %q", testHTMLBody, receivedMsg) + } + case <-time.After(2 * time.Second): + t.Error("timed out waiting for the email to be received") + } + // try to send invalid email + if err := eq.Send(notification.Notification{}); err == nil { + t.Error("expected error sending invalid email") + } + // try to compose a invalid email + if body, err := eq.composeBody(notification.Notification{}); err == nil { + t.Error("expected error composing invalid email") + } else if body != nil { + t.Error("expected body to be nil") + } + // try to send email to an invalid SMTP server + badEq, err := NewEmailQueue(ctx, &EmailConfig{ + SMTPServer: testServerAddr, + SMTPPort: 8080, + FromName: testSenderName, + FromAddress: testSender, + }) + if err != nil { + t.Fatal(err) + } + if err := badEq.Send(notification.Notification{ + Params: notification.NotificationParams{ + To: testReceiver, + Subject: testSubject, + }, + Body: nil, + PlainBody: []byte(testBody), + }); err == nil { + t.Error("expected error sending email") + } +} + +func TestPushSendEmail(t *testing.T) { + // create email queue and start it + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + errCh := make(chan error, 1) + eq, err := NewEmailQueue(ctx, &EmailConfig{ + SMTPServer: testServerAddr, + SMTPPort: testServerPort, + FromName: testSenderName, + FromAddress: testSender, + ErrorCh: errCh, + }) + if err != nil { + t.Fatal(err) + } + eq.Start() + defer eq.Stop() + // push email + if err := eq.Push(notification.Notification{ + Params: notification.NotificationParams{ + To: testReceiver, + Subject: testSubject, + }, + Body: nil, + PlainBody: []byte(testBody), + }); err != nil { + t.Fatal(err) + } + // check if the email was received + select { + case receivedMsg := <-inboxChan: + if !strings.Contains(receivedMsg, testSubject) { + t.Errorf("expected email content to contain %q, got %q", testSubject, receivedMsg) + } + if !strings.Contains(receivedMsg, testBody) { + t.Errorf("expected email content to contain %q, got %q", testBody, receivedMsg) + } + case <-time.After(2 * time.Second): + t.Error("timed out waiting for the email to be received") + } + // sleep to pop nil email + time.Sleep(2 * time.Second) + // push invalid email + if err := eq.Push(notification.Notification{}); err == nil { + t.Error("expected error pushing invalid email") + } +} diff --git a/notification/email/errors.go b/notification/email/errors.go new file mode 100644 index 0000000..1415164 --- /dev/null +++ b/notification/email/errors.go @@ -0,0 +1,33 @@ +package email + +import "github.com/simpleauthlink/authapi/internal" + +var ( + // ErrInvalidConfig is the error returned when the configuration is invalid. + ErrInvalidConfig = internal.NewErr("invalid configuration") + // ErrInitQueue is the error returned when the queue cannot be initialized. + ErrInitQueue = internal.NewErr("error initializing the queue") + // ErrInvalidEmail is the error returned when the email is invalid. + ErrInvalidEmail = internal.NewErr("invalid email") + // ErrInvalidTemplate is the error returned when the template is invalid. + ErrInvalidTemplate = internal.NewErr("invalid template") + // ErrSendEmail is the error returned when the email cannot be sent. + ErrSendEmail = internal.NewErr("error sending email") + // ErrComposeEmail is the error returned when the email cannot be composed. + ErrComposeEmail = internal.NewErr("error composing email") + // ErrParseAddress is the error returned when the email address cannot + // be parsed. + ErrParseAddress = internal.NewErr("error parsing email address") + // ErrSetBoundary is the error returned when the boundary cannot be set + // when a multipart email is composed. + ErrSetBoundary = internal.NewErr("error setting boundary") + // ErrWriteHTMLBody is the error returned when the email plain body cannot + // be written. + ErrWriteBody = internal.NewErr("error writing email plain body") + // ErrWriteHTMLBody is the error returned when the email HTML body cannot + // be written. + ErrWriteHTMLBody = internal.NewErr("error writing email html body") + // ErrCloseEmailWriter is the error returned when the email writer cannot + // be closed after composing the email. + ErrCloseEmailWriter = internal.NewErr("error closing email writer") +) diff --git a/notification/email/template.go b/notification/email/template.go new file mode 100644 index 0000000..4e9bb7a --- /dev/null +++ b/notification/email/template.go @@ -0,0 +1,92 @@ +package email + +import ( + "bytes" + htmltemplate "html/template" + texttemplate "text/template" + + "github.com/simpleauthlink/authapi/notification" +) + +// EmailTemplate is the definition of an email template, which contains the +// HTML and plain text placeholders to be filled with the data. +type EmailTemplate struct { + HTML string + Plain string +} + +// Compose methods fills the email template with the data and returns the email +// ready to be sent. It returns the email or an error if the template could not +// be filled. It tries to fill both the HTML and plain text templates, but if +// any of them is missing, it will return an error. If some of the placeholders +// in the template are not filled, they will be left as they are. +func (temp *EmailTemplate) Compose(params notification.NotificationParams, data any) (notification.Notification, error) { + if !params.Valid() { + return notification.Notification{}, ErrComposeEmail + } + // compose the html body + body, err := temp.composeHTML(data) + if err != nil { + return notification.Notification{}, err + } + // compose the plain body + plainBody, err := temp.composePlain(data) + if err != nil { + return notification.Notification{}, err + } + // if both bodies are empty, return an error + if plainBody == nil && body == nil { + return notification.Notification{}, ErrInvalidTemplate + } + // return the email with the filled bodies + email := notification.Notification{ + Params: params, + Body: body, + PlainBody: plainBody, + } + return email, nil +} + +// composePlain method fills the plain text template with the data and returns the +// filled content as a byte slice. It returns the filled template or an error +// if the template could not be filled. If the plain text template is empty, it +// returns nil and no error. +func (temp *EmailTemplate) composePlain(data any) ([]byte, error) { + if temp.Plain == "" { + return nil, nil + } + // parse the placeholder plain body template + tmpl, err := texttemplate.New("plain").Parse(temp.Plain) + if err != nil { + return nil, err + } + // inflate the template with the data + buf := new(bytes.Buffer) + if err := tmpl.Execute(buf, data); err != nil { + return nil, err + } + // return the notification with the plain body filled with the data + return buf.Bytes(), nil +} + +// composeHTML method fills the HTML template with the data and returns the filled +// content as a byte slice. It returns the filled template or an error if the +// template could not be filled. If the HTML template is empty, it returns nil +// and no error. +func (temp *EmailTemplate) composeHTML(data any) ([]byte, error) { + if temp.HTML == "" { + return nil, nil + } + // parse the email template + tmpl, err := htmltemplate.New("html").Parse(temp.HTML) + if err != nil { + return nil, err + } + // inflate the template with the data + buf := new(bytes.Buffer) + if err := tmpl.Execute(buf, data); err != nil { + return nil, err + } + // set the body of the notification + return buf.Bytes(), nil +} diff --git a/notification/email/template_test.go b/notification/email/template_test.go new file mode 100644 index 0000000..1d5881d --- /dev/null +++ b/notification/email/template_test.go @@ -0,0 +1,152 @@ +package email + +import ( + "testing" + + "github.com/simpleauthlink/authapi/notification" +) + +var testTemplate = &EmailTemplate{ + HTML: "

{{.Title}}

{{.Content}}

", + Plain: "Title: {{.Title}}\nContent: {{.Content}}", +} + +type testData struct { + Title string + Content string +} + +func TestCompose(t *testing.T) { + // valid data + data := testData{ + Title: "Test Title", + Content: "Test Content", + } + email, err := testTemplate.Compose(notification.NotificationParams{ + To: testReceiver, + Subject: testSubject, + }, data) + if err != nil { + t.Fatalf("expected nil, got error: %v", err) + } + if email.Params.To != testReceiver { + t.Fatalf("got %v, want %v", email.Params.To, testReceiver) + } + if email.Params.Subject != testSubject { + t.Fatalf("got %v, want %v", email.Params.Subject, testSubject) + } + expectedBody := "

Test Title

Test Content

" + expectedPlain := "Title: Test Title\nContent: Test Content" + if string(email.Body) != expectedBody { + t.Fatalf("got %v, want %v", string(email.Body), expectedBody) + } + if string(email.PlainBody) != expectedPlain { + t.Fatalf("got %v, want %v", string(email.PlainBody), expectedPlain) + } + // no subject + if _, err := testTemplate.Compose(notification.NotificationParams{ + To: testReceiver, + Subject: "", + }, data); err == nil { + t.Fatalf("expected error, got nil") + } + // no to address + if _, err := testTemplate.Compose(notification.NotificationParams{ + To: "", + Subject: testSubject, + }, data); err == nil { + t.Fatalf("expected error, got nil") + } + // bad to address + if _, err := testTemplate.Compose(notification.NotificationParams{ + To: "bad email", + Subject: testSubject, + }, data); err == nil { + t.Fatalf("expected error, got nil") + } + // invalid template + emptyTemplate := &EmailTemplate{} + validParams := notification.NotificationParams{ + To: testReceiver, + Subject: testSubject, + } + if _, err := emptyTemplate.Compose(validParams, data); err == nil { + t.Fatalf("expected error, got nil") + } + // no html template + noHTMLTemplate := &EmailTemplate{Plain: testTemplate.Plain} + onlyPlainEmail, err := noHTMLTemplate.Compose(validParams, data) + if err != nil { + t.Fatalf("expected nil, got error: %v", err) + } + if onlyPlainEmail.Body != nil { + t.Fatalf("expected nil, got %v", string(onlyPlainEmail.Body)) + } + if string(onlyPlainEmail.PlainBody) != expectedPlain { + t.Fatalf("got %v, want %v", string(onlyPlainEmail.PlainBody), expectedPlain) + } + // no plain template + noPlainTemplate := &EmailTemplate{HTML: testTemplate.HTML} + onlyHTMLEmail, err := noPlainTemplate.Compose(validParams, data) + if err != nil { + t.Fatalf("expected nil, got error: %v", err) + } + if string(onlyHTMLEmail.Body) != expectedBody { + t.Fatalf("got %v, want %v", string(onlyHTMLEmail.Body), expectedBody) + } + if onlyHTMLEmail.PlainBody != nil { + t.Fatalf("expected nil, got %v", string(onlyHTMLEmail.PlainBody)) + } +} + +func Test_composePlain(t *testing.T) { + // valid data and template + data := testData{ + Title: "Test Title", + Content: "Test Content", + } + body, err := testTemplate.composePlain(data) + if err != nil { + t.Fatalf("expected nil, got error: %v", err) + } + expected := "Title: Test Title\nContent: Test Content" + if string(body) != expected { + t.Fatalf("got %v, want %v", string(body), expected) + } + // no plain template + wrongPlainTemplate := *testTemplate + wrongPlainTemplate.Plain = "" + body, err = wrongPlainTemplate.composePlain(data) + if err != nil { + t.Fatalf("expected nil, got error: %v", err) + } + if body != nil { + t.Fatalf("expected nil, got %v", string(body)) + } +} + +func Test_composeHTML(t *testing.T) { + // valid data and template + data := testData{ + Title: "Test Title", + Content: "Test Content", + } + body, err := testTemplate.composeHTML(data) + if err != nil { + t.Fatalf("expected nil, got error: %v", err) + } + expected := "

Test Title

Test Content

" + if string(body) != expected { + t.Fatalf("got %v, want %v", string(body), expected) + } + // no html template + wrongHTMLTemplate := *testTemplate + wrongHTMLTemplate.HTML = "" + body, err = wrongHTMLTemplate.composeHTML(data) + if err != nil { + t.Fatalf("expected nil, got error: %v", err) + } + if body != nil { + t.Fatalf("expected nil, got %v", string(body)) + } +} diff --git a/notification/notification.go b/notification/notification.go new file mode 100644 index 0000000..308c79a --- /dev/null +++ b/notification/notification.go @@ -0,0 +1,33 @@ +package notification + +import "net/mail" + +type NotificationParams struct { + To string + Subject string +} + +func (p NotificationParams) Valid() bool { + _, err := mail.ParseAddress(p.To) + return err == nil && p.Subject != "" +} + +type Notification struct { + Params NotificationParams + Body []byte + PlainBody []byte +} + +// Valid method checks if the email is valid. It returns true if the recipient +// email address, the subject and the body are not empty. +func (n *Notification) Valid() bool { + return n.Params.Valid() && max(len(n.Body), len(n.PlainBody)) > 0 +} + +type Queue interface { + Start() + Stop() + Pop() (Notification, bool) + Push(Notification) error + Send(Notification) error +} diff --git a/notification/notification_test.go b/notification/notification_test.go new file mode 100644 index 0000000..573947c --- /dev/null +++ b/notification/notification_test.go @@ -0,0 +1,90 @@ +package notification + +import ( + "testing" +) + +func TestNotificationParams_Valid(t *testing.T) { + tests := []struct { + name string + params NotificationParams + valid bool + }{ + { + name: "Valid params", + params: NotificationParams{To: "test@example.com", Subject: "Test Subject"}, + valid: true, + }, + { + name: "Invalid email", + params: NotificationParams{To: "invalid-email", Subject: "Test Subject"}, + valid: false, + }, + { + name: "Empty subject", + params: NotificationParams{To: "test@example.com", Subject: ""}, + valid: false, + }, + { + name: "Empty email and subject", + params: NotificationParams{To: "", Subject: ""}, + valid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.params.Valid(); got != tt.valid { + t.Errorf("expected valid: %v, got: %v", tt.valid, got) + } + }) + } +} + +func TestNotification_Valid(t *testing.T) { + tests := []struct { + name string + notification Notification + valid bool + }{ + { + name: "Valid notification with Body", + notification: Notification{ + Params: NotificationParams{To: "test@example.com", Subject: "Test Subject"}, + Body: []byte("Test Body"), + }, + valid: true, + }, + { + name: "Valid notification with PlainBody", + notification: Notification{ + Params: NotificationParams{To: "test@example.com", Subject: "Test Subject"}, + PlainBody: []byte("Test Plain Body"), + }, + valid: true, + }, + { + name: "Invalid notification with empty Body and PlainBody", + notification: Notification{ + Params: NotificationParams{To: "test@example.com", Subject: "Test Subject"}, + }, + valid: false, + }, + { + name: "Invalid notification with invalid Params", + notification: Notification{ + Params: NotificationParams{To: "invalid-email", Subject: "Test Subject"}, + Body: []byte("Test Body"), + }, + valid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.notification.Valid(); got != tt.valid { + t.Errorf("expected valid: %v, got: %v", tt.valid, got) + } + }) + } +} diff --git a/notification/templates/login/definition.go b/notification/templates/login/definition.go new file mode 100644 index 0000000..9a66b67 --- /dev/null +++ b/notification/templates/login/definition.go @@ -0,0 +1,64 @@ +package login + +import ( + _ "embed" + "regexp" + + "github.com/simpleauthlink/authapi/notification" + "github.com/simpleauthlink/authapi/notification/email" + "github.com/simpleauthlink/authapi/token" +) + +//go:embed template.html +var htmlTemplate string + +// Data struct contains the required data to fill the login email template. +type Data struct { + AppName string + Email string + Token string + Link string +} + +// Subject returns the email subject based on the login data. +func (d Data) Subject() string { + return "Your token for '" + d.AppName + "'" +} + +// FindToken function extracts the token from the email content. It uses a +// regular expression to fill the template with regex and find the token in +// the email content. Then it decodes the token and returns it. If the token +// is not found, it returns nil. +func FindToken(email, content string) *token.Token { + loginData := Data{ + AppName: `.+`, + Email: email, + Token: `(.+\..+)`, + Link: `.+`, + } + loginEmail, err := Template.Compose(notification.NotificationParams{ + To: email, + Subject: loginData.Subject(), + }, loginData) + if err != nil { + return nil + } + tokenRgx := regexp.MustCompile(string(loginEmail.PlainBody)) + tokenResult := tokenRgx.FindAllStringSubmatch(content, -1) + if len(tokenResult) < 1 || len(tokenResult[0]) < 2 { + return nil + } + return new(token.Token).SetString(tokenResult[0][1]) +} + +// Template is the login email template definition, which contains the HTML +// and plain text templates. +var Template = email.EmailTemplate{ + HTML: htmlTemplate, + Plain: `Hi, {{.Email}} +You can access to '{{.AppName}}' app using the following link: +{{.Link}} +It contains your login token: '{{.Token}}' +Which is only valid for you and for a short period of time. +If you didn't request this, you can ignore this email.`, +} diff --git a/assets/token_email_template.html b/notification/templates/login/template.html similarity index 60% rename from assets/token_email_template.html rename to notification/templates/login/template.html index 2a51785..ec1476c 100644 --- a/assets/token_email_template.html +++ b/notification/templates/login/template.html @@ -1,5 +1,4 @@ - @@ -7,21 +6,13 @@ Your Magic Link for {{.AppName}} Login - + - -
@@ -29,35 +20,41 @@ style="border-collapse: collapse; border: 1px solid #cccccc;">
- + Logo

SimpleAuth.link

- πŸ‘‹ Hi, {{.EmailHandler}}! + +

πŸ‘‹ Hi, {{.Email}}!



- Your magic link to login to '{{.AppName}}' is ready πŸŽ‰. + Your magic link is ready! πŸŽ‰

+ You can access to your {{.AppName}} account using it πŸ”. +
Click the button below to login to your account. πŸ‘‡
- + -
- Login to + Login to {{.AppName}}
+ +
+ It contains your login token: +
+
{{.Token}}
+
+ Which is only valid for you and for a short period of time.
- {{.Token}} + If you didn't request this, you can ignore this email.
@@ -67,16 +64,14 @@

SimpleAuth or copy and paste the following link in your browser:

-
{{.MagicLink}}
+
{{.Link}}


If you did not request this, please ignore this email.

diff --git a/token/app.go b/token/app.go new file mode 100644 index 0000000..a952a14 --- /dev/null +++ b/token/app.go @@ -0,0 +1,174 @@ +package token + +import ( + "bytes" + "encoding/base64" + "encoding/hex" + "strings" + "time" +) + +// App represents an application that can request tokens. It has a name, a +// redirect URI, and a session duration. +type App struct { + Name string + RedirectURI string + SessionDuration time.Duration + AppSecretHash []byte +} + +// Valid method returns true if the app is valid, false otherwise. An app is +// considered valid if its name is between 3 and 20 characters, its redirect +// URI is a valid URI, and its session duration is between 5 minutes and 24 +// hours. +func (app *App) Valid(secretHash []byte) bool { + if app == nil { + return false + } + // check if the app name is between the min and max length + if len(app.Name) < appNameMinLen || len(app.Name) > appNameMaxLen { + return false + } + // check if the redirect URI is valid + if !uriRegexp.MatchString(app.RedirectURI) || len(app.RedirectURI) > redirectURIMaxLen { + return false + } + // check if the session duration is between the min and max duration + if app.SessionDuration < minDuration || app.SessionDuration > maxDuration { + return false + } + if secretHash != nil { + return bytes.Equal(app.AppSecretHash, secretHash) + } + return true +} + +// Attributes method returns the app's attributes as a slice of strings. This +// is useful for encoding the app. +func (app *App) Attributes() []string { + return []string{app.Name, app.RedirectURI, app.SessionDuration.String(), hex.EncodeToString(app.AppSecretHash)} +} + +// SetAttributes method sets the app's attributes from a slice of strings. This +// is useful for decoding the app. +func (app *App) SetAttributes(attrs []string) *App { + // check if the slice has the correct number of attributes + if len(attrs) != 4 { + return nil + } + // parse the session duration + duration, err := time.ParseDuration(attrs[2]) + if err != nil { + return nil + } + // if the app is nil, create a new app + if app == nil { + app = new(App) + } + // set the app's attributes + app.Name = attrs[0] + app.RedirectURI = attrs[1] + app.SessionDuration = duration + appSecretHash, err := hex.DecodeString(attrs[3]) + if err != nil { + return nil + } + if len(appSecretHash) != secretHashSize { + return nil + } + app.AppSecretHash = appSecretHash + // check if the app is valid and return it if it is + if !app.Valid(nil) { + return nil + } + return app +} + +// String method returns the app as a string. This is useful for debugging +// and encoding the app. The resulting string is the app's attributes joined +// by the app data separator. +func (app *App) String() string { + if !app.Valid(nil) { + return "" + } + // join the app's attributes with the app data separator + return strings.Join(app.Attributes(), appDataSeparator) +} + +// SetString method sets the app from a string. This is useful for decoding +// the app. The string should be the app's attributes joined by the app data +// separator. +func (app *App) SetString(data string) *App { + b := strings.Split(data, appDataSeparator) + return app.SetAttributes(b) +} + +// Bytes method returns the app as a byte slice. This is useful for encoding +// the app. It is equivalent to converting the app to a string and then +// converting the string to a byte slice. +func (app *App) Bytes() []byte { + return []byte(app.String()) +} + +// SetBytes method sets the app from a byte slice. This is useful for decoding +// the app. It is equivalent to converting the byte slice to a string and then +// converting the string to the app. +func (app *App) SetBytes(data []byte) *App { + return app.SetString(string(data)) +} + +// Marshal method returns the app as a base64-encoded byte slice. It is used +// to be included in the app ID, which makes it self-contained. +func (app *App) Marshal() []byte { + if !app.Valid(nil) { + return nil + } + bApp := app.Bytes() + b := make([]byte, base64.RawStdEncoding.EncodedLen(len(bApp))) + base64.RawStdEncoding.Encode(b, bApp) + return b +} + +// Unmarshal method sets the app from a base64-encoded byte slice. It is used +// to extract the app from the app ID. +func (app *App) Unmarshal(data []byte) *App { + b := make([]byte, base64.RawStdEncoding.DecodedLen(len(data))) + if _, err := base64.RawStdEncoding.Decode(b, data); err != nil { + return nil + } + return app.SetBytes(b) +} + +// ID method returns the app ID of the app. The app ID is a self-contained +// representation of the app that can be used to generate tokens. It is +// created by encoding the app as a base64-encoded byte slice using the +// Marshal method. +func (app *App) ID(secret *Secret) *AppID { + if !app.Valid(secret.Hash()) { + return nil + } + return new(AppID).SetBytes(app.Marshal()) +} + +// SetID method sets the app from an app ID. The app ID is a self-contained +// representation of the app that can be used to generate tokens. The app is +// extracted from the app ID by decoding the app as a base64-encoded byte +// slice using the Unmarshal method. +func (app *App) SetID(id *AppID) *App { + if id == nil { + return nil + } + return app.Unmarshal(id.Bytes()) +} + +func (app *App) SetSecret(secret *Secret) *App { + if app == nil { + return nil + } + if secret == nil { + return app + } + // set the app secret hash + app.AppSecretHash = secret.Hash() + return app +} diff --git a/token/app_test.go b/token/app_test.go new file mode 100644 index 0000000..a68acd0 --- /dev/null +++ b/token/app_test.go @@ -0,0 +1,269 @@ +package token + +import ( + "bytes" + "encoding/hex" + "testing" + "time" +) + +const ( + testAppName = "MySuperMegaApp" + testRedirectURI = "https://example.com/login?app=MySuperMegaApp" + testSessionDuration = time.Minute * 30 +) + +var testAppSecret = new(Secret).SetParts([]byte("super_secret_key"), []byte("super_secret_salt")) + +func TestValidApp(t *testing.T) { + app := &App{ + Name: testAppName, + RedirectURI: testRedirectURI, + SessionDuration: testSessionDuration, + } + if !app.Valid(nil) { + t.Errorf("expected valid app data") + } + // test app name + app.Name = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." + if app.Valid(nil) { + t.Errorf("expected invalid app data") + } + app.Name = "no" + if app.Valid(nil) { + t.Errorf("expected invalid app data") + } + app.Name = testAppName + // test redirect URI + app.RedirectURI = "https://example.com/login?app=lorem_ipsum_dolor_sit_amet_consectetur_adipiscing_elit_sed_do_eiusmod_tempor_incididunt_ut_labore_et_dolore_magna_aliqua" + if app.Valid(nil) { + t.Errorf("expected invalid app data") + } + app.RedirectURI = "no_url" + if app.Valid(nil) { + t.Errorf("expected invalid app data") + } + app.RedirectURI = testRedirectURI + // test session duration + app.SessionDuration = minDuration - 1 + if app.Valid(nil) { + t.Errorf("expected invalid app data") + } + app.SessionDuration = maxDuration + 1 + if app.Valid(nil) { + t.Errorf("expected invalid app data") + } + var nilApp *App + if nilApp.Valid(nil) { + t.Errorf("expected invalid app data") + } + app.AppSecretHash = testAppSecret.Hash() + servicePart := []byte("invalid-service-secret") + appPart := []byte("invalid-app-secret") + invalidSecret := new(Secret).SetParts(servicePart, appPart) + if app.Valid(invalidSecret.Hash()) { + t.Errorf("expected invalid app data") + } +} + +func TestAttributesSetAttributesApp(t *testing.T) { + if res := new(App).SetAttributes([]string{}); res != nil { + t.Errorf("expected nil, got %v", res) + } + strHash := hex.EncodeToString(testAppSecret.Hash()) + if res := new(App).SetAttributes([]string{testAppName, testRedirectURI, "no_duration", strHash}); res != nil { + t.Errorf("expected nil, got %v", res) + } + if res := new(App).SetAttributes([]string{testAppName, "no_url", testSessionDuration.String()}); res != nil { + t.Errorf("expected nil, got %v", res) + } + if res := new(App).SetAttributes([]string{testAppName, testRedirectURI, testSessionDuration.String(), "invalid_hash"}); res != nil { + t.Errorf("expected nil, got %v", res) + } + if res := new(App).SetAttributes([]string{testAppName, testRedirectURI, testSessionDuration.String(), "2bbc94cb9c916e1f6f1354ef30c1c80767b85159570304baa402c088180a0ec5"}); res != nil { + t.Errorf("expected nil, got %v", res) + } + app := &App{ + Name: testAppName, + RedirectURI: testRedirectURI, + SessionDuration: testSessionDuration, + } + servicePart := []byte("service-secret") + appPart := []byte("app-secret") + secret := new(Secret).SetParts(servicePart, appPart) + app.SetSecret(secret) + var nilApp *App + nilData := nilApp.SetAttributes(app.Attributes()) + if nilData == nil { + t.Fatalf("error decoding app data") + } + if nilData.Name != testAppName { + t.Errorf("expected app name %q, got %q", testAppName, nilData.Name) + } + if nilData.RedirectURI != testRedirectURI { + t.Errorf("expected redirect URI %q, got %q", testRedirectURI, nilData.RedirectURI) + } + if nilData.SessionDuration != testSessionDuration { + t.Errorf("expected session duration %v, got %v", testSessionDuration, nilData.SessionDuration) + } + data := new(App).SetAttributes(app.Attributes()) + if data == nil { + t.Fatalf("error decoding app data") + } + if data.Name != testAppName { + t.Errorf("expected app name %q, got %q", testAppName, data.Name) + } + if data.RedirectURI != testRedirectURI { + t.Errorf("expected redirect URI %q, got %q", testRedirectURI, data.RedirectURI) + } + if data.SessionDuration != testSessionDuration { + t.Errorf("expected session duration %v, got %v", testSessionDuration, data.SessionDuration) + } + // set an out of range duration to test if the app is valid + attrs := app.Attributes() + attrs[2] = "1s" + if res := new(App).SetAttributes(attrs); res != nil { + t.Errorf("expected nil, got %v", res) + } +} + +func TestStringSetStringApp(t *testing.T) { + if res := new(App).String(); res != "" { + t.Errorf("expected empty string, got %q", res) + } + app := &App{ + Name: testAppName, + RedirectURI: testRedirectURI, + SessionDuration: testSessionDuration, + } + servicePart := []byte("service-secret") + appPart := []byte("app-secret") + secret := new(Secret).SetParts(servicePart, appPart) + app.SetSecret(secret) + data := new(App).SetString(app.String()) + if data == nil { + t.Fatalf("error decoding app data") + } + if data.Name != testAppName { + t.Errorf("expected app name %q, got %q", testAppName, data.Name) + } + if data.RedirectURI != testRedirectURI { + t.Errorf("expected redirect URI %q, got %q", testRedirectURI, data.RedirectURI) + } + if data.SessionDuration != testSessionDuration { + t.Errorf("expected session duration %v, got %v", testSessionDuration, data.SessionDuration) + } +} + +func TestBytesSetBytesApp(t *testing.T) { + app := &App{ + Name: testAppName, + RedirectURI: testRedirectURI, + SessionDuration: testSessionDuration, + } + servicePart := []byte("service-secret") + appPart := []byte("app-secret") + secret := new(Secret).SetParts(servicePart, appPart) + app.SetSecret(secret) + data := new(App).SetBytes(app.Bytes()) + if data == nil { + t.Fatalf("error decoding app data") + } + if data.Name != testAppName { + t.Errorf("expected app name %q, got %q", testAppName, data.Name) + } + if data.RedirectURI != testRedirectURI { + t.Errorf("expected redirect URI %q, got %q", testRedirectURI, data.RedirectURI) + } + if data.SessionDuration != testSessionDuration { + t.Errorf("expected session duration %v, got %v", testSessionDuration, data.SessionDuration) + } +} + +func TestMarshalUnmarshalApp(t *testing.T) { + if res := new(App).Marshal(); res != nil { + t.Errorf("expected nil, got %v", res) + } + if res := new(App).Unmarshal([]byte{1}); res != nil { + t.Errorf("expected nil, got %v", res) + } + app := &App{ + Name: testAppName, + RedirectURI: testRedirectURI, + SessionDuration: testSessionDuration, + } + servicePart := []byte("service-secret") + appPart := []byte("app-secret") + secret := new(Secret).SetParts(servicePart, appPart) + app.SetSecret(secret) + data := new(App).Unmarshal(app.Marshal()) + if data == nil { + t.Fatalf("error decoding app data") + } + if data.Name != testAppName { + t.Errorf("expected app name %q, got %q", testAppName, data.Name) + } + if data.RedirectURI != testRedirectURI { + t.Errorf("expected redirect URI %q, got %q", testRedirectURI, data.RedirectURI) + } + if data.SessionDuration != testSessionDuration { + t.Errorf("expected session duration %v, got %v", testSessionDuration, data.SessionDuration) + } +} + +func TestAppID(t *testing.T) { + servicePart := []byte("service-secret") + appPart := []byte("app-secret") + secret := new(Secret).SetParts(servicePart, appPart) + if id := new(App).ID(secret); id != nil { + t.Errorf("expected nil, got %v", id) + } + app := &App{ + Name: testAppName, + RedirectURI: testRedirectURI, + SessionDuration: testSessionDuration, + } + app.SetSecret(secret) + id := app.ID(secret) + if id == nil { + t.Fatalf("error decoding app ID") + } + if !bytes.Equal(id.Bytes(), app.Marshal()) { + t.Errorf("expected %v, got %v", app.Marshal(), id.Bytes()) + } + if res := new(App).SetID(nil); res != nil { + t.Errorf("expected nil, got %v", res) + } + newApp := new(App).SetID(id) + if newApp == nil { + t.Fatalf("error decoding app ID") + } + if newApp.String() != app.String() { + t.Errorf("expected %s, got %s", app.String(), newApp.String()) + } +} + +func TestSetSecretApp(t *testing.T) { + var nilApp *App + if res := nilApp.SetSecret(nil); res != nil { + t.Errorf("expected nil, got %v", res) + } + app := &App{ + Name: testAppName, + RedirectURI: testRedirectURI, + SessionDuration: testSessionDuration, + } + if res := app.SetSecret(nil); res == nil { + t.Errorf("expected nil, got %v", res) + } + if app.AppSecretHash != nil { + t.Errorf("expected nil, got %v", app.AppSecretHash) + } + servicePart := []byte("service-secret") + appPart := []byte("app-secret") + secret := new(Secret).SetParts(servicePart, appPart) + app.SetSecret(secret) + if !bytes.Equal(app.AppSecretHash, secret.Hash()) { + t.Errorf("expected %v, got %v", secret.Hash(), app.AppSecretHash) + } +} diff --git a/token/consts.go b/token/consts.go new file mode 100644 index 0000000..0de326b --- /dev/null +++ b/token/consts.go @@ -0,0 +1,19 @@ +package token + +import ( + "regexp" + "time" +) + +const ( + appDataSeparator = "|" + appNameMinLen = 3 + appNameMaxLen = 20 + redirectURIPattern = `^https?://(?:localhost|[a-zA-Z0-9-]+(?:\.[a-zA-Z0-9-]+)+)(?::\d+)?(/[a-zA-Z0-9-._~:/?#[\]@!$&'()*+,;=]*)?$` + redirectURIMaxLen = 80 + minDuration = 30 * time.Second + maxDuration = 180 * 24 * time.Hour + tokenSeparator = '.' +) + +var uriRegexp = regexp.MustCompile(redirectURIPattern) diff --git a/token/errors.go b/token/errors.go new file mode 100644 index 0000000..287a594 --- /dev/null +++ b/token/errors.go @@ -0,0 +1,10 @@ +package token + +import "fmt" + +var ( + ErrInvalidAppID = fmt.Errorf("invalid app ID") + ErrInvalidAppName = fmt.Errorf("invalid app name") + ErrInvalidRedirectURI = fmt.Errorf("invalid redirect URI") + ErrInvalidSessionDuration = fmt.Errorf("invalid session duration") +) diff --git a/token/expiration.go b/token/expiration.go new file mode 100644 index 0000000..31e7f40 --- /dev/null +++ b/token/expiration.go @@ -0,0 +1,134 @@ +package token + +import ( + "encoding/base64" + "time" +) + +// Expiration represents a time when a token expires. It is a wrapper around +// time.Time that provides additional methods for setting and getting the +// expiration time. +type Expiration time.Time + +// Valid method returns true if the expiration is valid, false otherwise. An +// expiration is considered valid if it is in the future. +func (exp *Expiration) Valid() bool { + return time.Now().Before(exp.Time()) +} + +// Time method returns the expiration time as a time.Time. +func (exp *Expiration) Time() time.Time { + return time.Time(*exp) +} + +// SetTime method sets the expiration time from a time.Time. If the expiration +// is nil, a new expiration is created. +func (exp *Expiration) SetTime(t time.Time) *Expiration { + // if no expiration is provided, initialize a new one + if exp == nil { + exp = new(Expiration) + } + // set the expiration time + *exp = Expiration(t) + // if the expiration is invalid, return nil + if !exp.Valid() { + return nil + } + return exp +} + +// Duration method returns the duration until the expiration time. +func (exp *Expiration) Duration() time.Duration { + return time.Until(exp.Time()) +} + +// SetDuration method sets the expiration time from a duration. If the duration +// is invalid, the expiration time is not set. +func (exp *Expiration) SetDuration(d time.Duration) *Expiration { + if d < minDuration || d > maxDuration { + return nil + } + if exp == nil { + exp = new(Expiration) + } + return exp.SetTime(time.Now().Add(d)) +} + +// String method returns the expiration time as a string in RFC3339Nano format. +// It is useful for encoding the expiration time. If the expiration is nil, an +// empty string is returned. +func (exp *Expiration) String() string { + if exp == nil { + return "" + } + if !exp.Valid() { + return "" + } + return exp.Time().Format(time.RFC3339Nano) +} + +// SetString method sets the expiration time from a string in RFC3339Nano +// format. It is useful for decoding the expiration time. If the string +// is invalid, the expiration time is not set and nil is returned. If the +// expiration is nil, a new expiration is created. If the resulting expiration +// is invalid, nil is returned. +func (exp *Expiration) SetString(data string) *Expiration { + // parse the expiration time + t, err := time.Parse(time.RFC3339Nano, data) + if err != nil { + return nil + } + // if the expiration is nil, initialize a new one + if exp == nil { + exp = new(Expiration) + } + // set the expiration time + *exp = Expiration(t) + // if the expiration is invalid, return nil + if !exp.Valid() { + return nil + } + return exp +} + +// Bytes method returns the expiration time as a byte slice. It is useful for +// encoding the expiration time. If the expiration is nil, nil is returned. It +// is equivalent to converting the expiration time to a string and then +// converting the string to a byte slice. +func (exp *Expiration) Bytes() []byte { + if exp.String() == "" { + return nil + } + return []byte(exp.String()) +} + +// SetBytes method sets the expiration time from a byte slice. It is useful for +// decoding the expiration time. It is equivalent to converting the byte slice +// to a string and then setting the expiration time from the string. +func (exp *Expiration) SetBytes(data []byte) *Expiration { + return exp.SetString(string(data)) +} + +// Marshal method returns the expiration time as a base64 encoded byte slice. It +// is useful for encoding the expiration time. If the expiration is nil or +// invalid, nil is returned. +func (exp *Expiration) Marshal() []byte { + bExp := exp.Bytes() + if len(bExp) == 0 || bExp[0] == 0 { + return nil + } + b := make([]byte, base64.RawStdEncoding.EncodedLen(len(bExp))) + base64.RawStdEncoding.Encode(b, bExp) + return b +} + +// Unmarshal method sets the expiration time from a base64 encoded byte slice. It +// is useful for decoding the expiration time. If the expiration is nil or +// invalid, nil is returned. +func (exp *Expiration) Unmarshal(data []byte) *Expiration { + b := make([]byte, base64.RawStdEncoding.DecodedLen(len(data))) + if _, err := base64.RawStdEncoding.Decode(b, data); err != nil { + return nil + } + return exp.SetBytes(b) +} diff --git a/token/expiration_test.go b/token/expiration_test.go new file mode 100644 index 0000000..f7b2525 --- /dev/null +++ b/token/expiration_test.go @@ -0,0 +1,123 @@ +package token + +import ( + "bytes" + "testing" + "time" +) + +func TestValidExpiration(t *testing.T) { + t.Parallel() + exp := new(Expiration).SetDuration(minDuration + time.Second) + if !exp.Valid() { + t.Errorf("expected valid expiration, got invalid") + } + time.Sleep(minDuration + (time.Second * 2)) + if exp.Valid() { + t.Errorf("expected invalid expiration, got valid") + } +} + +func TestTimeSetTimeExpiration(t *testing.T) { + var nilExp *Expiration + exp := nilExp.SetTime(time.Now().Add(minDuration * 2)) + if exp == nil { + t.Fatalf("expected valid expiration, got nil") + } + exp = new(Expiration).SetTime(time.Now().Add(minDuration * 2)) + if exp == nil { + t.Fatalf("expected valid expiration, got nil") + } + expTime := exp.Time() + expected := time.Now().Add(minDuration * 2) + if expected.Sub(expTime) > time.Millisecond*300 { + t.Errorf("expected %v, got %v", expected, expTime) + } + invalidTime := time.Now().Add(-time.Second) + if exp := new(Expiration).SetTime(invalidTime); exp != nil { + t.Errorf("expected nil, got %v", exp) + } +} + +func TestDurationSetDurationExpiration(t *testing.T) { + exp := new(Expiration).SetDuration(minDuration - 1) + if exp != nil { + t.Fatalf("expected nil, got %v", exp) + } + exp = exp.SetDuration(minDuration * 2) + if exp == nil { + t.Fatalf("expected valid expiration, got nil") + } + expectedDuration := time.Duration(minDuration * 2) + if expectedDuration-exp.Duration() > time.Millisecond*300 { + t.Errorf("expected %v, got %v", expectedDuration, exp.Duration()) + } +} + +func TestStringSetStringExpiration(t *testing.T) { + exp := new(Expiration).SetDuration(minDuration * 2) + str := exp.String() + decoded := new(Expiration).SetString(str) + if decoded == nil { + t.Fatalf("expected valid expiration, got nil") + } + if exp.String() != decoded.String() { + t.Errorf("expected %v, got %v", exp, decoded) + } + if exp := new(Expiration).SetString("invalid"); exp != nil { + t.Errorf("expected nil, got %v", exp) + } + if exp := new(Expiration).String(); exp != "" { + t.Errorf("expected empty string, got %v", exp) + } + var nilExp *Expiration + if nilExp.String() != "" { + t.Errorf("expected empty string, got %v", nilExp.String()) + } + if nilExp = nilExp.SetString(str); nilExp == nil { + t.Fatalf("expected valid expiration, got nil") + } + if nilExp.String() != str { + t.Errorf("expected %v, got %v", str, nilExp.String()) + } + invalidTime := time.Now().Add(-time.Second).Format(time.RFC3339Nano) + if exp := new(Expiration).SetString(invalidTime); exp != nil { + t.Errorf("expected nil, got %v", exp) + } +} + +func TestBytesSetBytesExpiration(t *testing.T) { + exp := new(Expiration).SetDuration(minDuration * 2) + b := exp.Bytes() + decoded := new(Expiration).SetBytes(b) + if decoded == nil { + t.Fatalf("expected valid expiration, got nil") + } + if !bytes.Equal(b, decoded.Bytes()) { + t.Errorf("expected %v, got %v", exp, decoded) + } + if exp := new(Expiration).SetBytes([]byte("invalid")); exp != nil { + t.Errorf("expected nil, got %v", exp) + } + if exp := new(Expiration).Bytes(); exp != nil { + t.Errorf("expected nil, got %v", exp) + } +} + +func TestMarshalUnmarshalExpiration(t *testing.T) { + exp := new(Expiration).SetDuration(minDuration * 2) + encoded := exp.Marshal() + decoded := new(Expiration).Unmarshal(encoded) + if decoded == nil { + t.Fatalf("expected valid expiration, got nil") + } + if exp.String() != decoded.String() { + t.Errorf("expected %v, got %v", exp, decoded) + } + if res := new(Expiration).Marshal(); res != nil { + t.Errorf("expected nil, got %v", res) + } + if res := new(Expiration).Unmarshal([]byte{1}); res != nil { + t.Errorf("expected nil, got %v", res) + } +} diff --git a/token/id.go b/token/id.go new file mode 100644 index 0000000..e7f0699 --- /dev/null +++ b/token/id.go @@ -0,0 +1,232 @@ +package token + +import ( + "crypto/ed25519" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" +) + +// AppID represents an application ID that is used to generate and verify +// tokens. It is a wrapper around a byte slice that provides additional +// methods for setting and getting the application ID. +type AppID []byte + +// String method returns the application ID as a string. If the application +// ID is nil, an empty string is returned. It internally calls the Bytes +// method to get the application ID as a byte slice and converts it to a +// string. +func (id *AppID) String() string { + return string(id.Bytes()) +} + +// SetString method sets the application ID from a string. If the application +// ID is nil, a new application ID is created. If the string is empty, the +// application ID is not set. It internally calls the SetBytes method to set +// the application ID from a byte slice. +func (id *AppID) SetString(data string) *AppID { + return id.SetBytes([]byte(data)) +} + +// Bytes method returns the application ID as a byte slice. +func (id *AppID) Bytes() []byte { + if id == nil { + return nil + } + return []byte(*id) +} + +// SetBytes method sets the application ID from a byte slice. If the +// application ID is nil, a new application ID is created. If the byte slice +// is empty, the application ID is not set. It internally calls the Unmarshal +// method of the App to check that the application ID is valid before setting +// the application ID from the byte slice. If the resulting application is not +// valid, the application ID is not set and nil is returned. +func (id *AppID) SetBytes(data []byte) *AppID { + // check if the application ID is valid + if !new(App).Unmarshal(data).Valid(nil) { + return nil + } + // if no application ID is provided, create a new one + if id == nil { + id = new(AppID) + } + // set the application ID + *id = data + return id +} + +// PrivKey method returns the private key for the application ID. If the +// application ID is nil or the secret is invalid, nil is returned. It +// internally calls the Bytes method to get the application ID as a byte +// slice and uses it to generate the private key. The private key is +// generated by hashing the application ID with the secret and using the +// resulting hash as the seed for an ed25519 private key. +func (id *AppID) PrivKey(secret Secret) ed25519.PrivateKey { + if !secret.Valid() { + return nil + } + if id == nil { + return nil + } + bID := id.Bytes() + if len(bID) == 0 { + return nil + } + hFn := hmac.New(sha256.New, secret.Bytes()) + hID := hFn.Sum(bID) + return ed25519.NewKeyFromSeed(hID[:32]) +} + +// Sign method returns the signature of the message for the application ID. +// If the application ID is nil, the message is empty, or the secret is +// invalid, nil is returned. It internally calls the PrivKey method to get +// the private key for the application ID and uses it to sign the message. +// The message is signed by appending a nonce to it and hashing the result +// with the private key. The signature is then encoded to base64 before +// being returned to be used as a part of the token, keeping it as short as +// possible. +func (id *AppID) Sign(secret Secret, msg []byte) []byte { + // check if the application ID is valid or the message is empty + if id == nil || len(msg) == 0 { + return nil + } + // get the private key for the application ID and the secret + privKey := id.PrivKey(secret) + if len(privKey) == 0 { + return nil + } + // append the message with a nonce and hash it with the private key + data := append(msg, hedgedNonce(privKey[:], msg)...) + // sign the data with the private key + rawSign := ed25519.Sign(privKey, data[:]) + // encode the signature to base64 and return it + sign := make([]byte, base64.RawStdEncoding.EncodedLen(len(rawSign))) + base64.RawStdEncoding.Encode(sign, rawSign) + return sign +} + +// Verify method returns true if the signature of the message is valid for +// the application ID. If the application ID is nil, the message is empty, +// the signature is empty, or the secret is invalid, false is returned. It +// internally calls the PrivKey method to get the private key for the +// application ID and uses it to verify the signature. The message is +// verified by appending a nonce to it and hashing the result with the +// public key. The signature is then decoded from base64 and verified with +// the public key to ensure that it was signed by the private key. +func (id *AppID) Verify(secret Secret, msg, sig []byte) bool { + // check if the application ID is valid or the message and signature are + // not empty + if id == nil || len(msg) == 0 || len(sig) == 0 { + return false + } + // get the private key for the application ID and the secret + privKey := id.PrivKey(secret) + if privKey == nil { + return false + } + // decode sign from base64 + rawSign := make([]byte, base64.RawStdEncoding.DecodedLen(len(sig))) + if _, err := base64.RawStdEncoding.Decode(rawSign, sig); err != nil { + return false + } + // recover the data with the nonce and the message + data := append(msg, hedgedNonce(privKey[:], msg)...) + // verify the data with the public key + pubKey := privKey.Public().(ed25519.PublicKey) + return ed25519.Verify(pubKey, data, rawSign) +} + +// GenerateToken method returns a token for the application ID. If the app ID +// is nil, the secret is invalid, or the email is empty, nil is returned. It +// internally calls the Message method to generate the message for the token +// and uses it to sign the message. The token is generated by hashing the +// application ID with the email and expiration time, and signing the result +// with the secret. The signature is then used to create a token with the +// expiration time and signature. +func (id *AppID) GenerateToken(secret Secret, email string) Token { + // check if the application ID is valid + if id == nil { + return nil + } + // get the application for the application ID + app := new(App).SetID(id) + if app == nil { + return nil + } + // calculate the expiration time for the current app + exp := new(Expiration).SetDuration(app.SessionDuration) + // get the message to sign + msg := id.Message(email, *exp) + // sign the message with the secret + sig := id.Sign(secret, msg) + // check if the signature is valid + if len(sig) == 0 { + return nil + } + // create a new token with the expiration time and signature + return *new(Token).SetExpiration(*exp).SetSignature(sig) +} + +// Message method returns the message for the application ID. If the app ID +// is nil, the email is empty, or the expiration time is invalid, nil is +// returned. It is used to generate the message for the token by hashing the +// application ID with the email and expiration time. The message is generated +// by appending the application ID with the email and expiration time, and +// hashing the result with sha256 to create a unique message for the token. +func (id *AppID) Message(email string, exp Expiration) []byte { + // check if the application ID is valid, the email is not empty, and the + // expiration time is valid + if id == nil || len(email) == 0 || !exp.Valid() { + return nil + } + // hash the application ID with the email and expiration time + hmsg := sha256.Sum256(append(append(id.Bytes(), []byte(email)...), exp.Bytes()...)) + return hmsg[:] +} + +// VerifyToken method returns true if the token is valid for the application +// ID. If the app ID is nil, the token is nil, the secret is invalid, or the +// email is empty, false is returned. It is used to verify the token by +// checking that the expiration time is valid and the signature is correct. +// The token is verified by hashing the application ID with the email and +// expiration time, and verifying the signature with the secret. +func (id *AppID) VerifyToken(token Token, secret Secret, email string) bool { + // check if the application ID is valid + if id == nil { + return false + } + // check if the expiration time is valid + exp := token.Expiration() + if exp == nil || !exp.Valid() { + return false + } + // check if the token contains a signature + sig := token.Signature() + if len(sig) == 0 { + return false + } + // get the message to verify + msg := id.Message(email, *exp) + if len(msg) == 0 { + return false + } + // verify the token with the secret + return id.Verify(secret, msg, sig) +} + +// hedgedNonce function returns a nonce for the application ID. It is used to +// generate a unique nonce for the application ID by hashing the application +// ID with the message. The nonce is generated by appending the application ID +// with the message and hashing the result with sha256. It helps to prevent +// replay attacks by ensuring that the nonce is unique for each message. +func hedgedNonce(inputs ...[]byte) []byte { + if len(inputs) == 0 || len(inputs[0]) == 0 { + return nil + } + hFn := hmac.New(sha256.New, inputs[0]) + for _, in := range inputs[1:] { + hFn.Write(in) + } + return hFn.Sum(nil) +} diff --git a/token/id_test.go b/token/id_test.go new file mode 100644 index 0000000..8419a2d --- /dev/null +++ b/token/id_test.go @@ -0,0 +1,227 @@ +package token + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "testing" + "time" +) + +func TestStringSetStringAppID(t *testing.T) { + if id := new(AppID).SetString("testID"); id != nil { + t.Errorf("expected nil, got %v", id) + } + app := &App{ + Name: testAppName, + RedirectURI: testRedirectURI, + SessionDuration: testSessionDuration, + } + servicePart := []byte("service-secret") + appPart := []byte("app-secret") + secret := new(Secret).SetParts(servicePart, appPart) + app.SetSecret(secret) + id := app.ID(secret) + if id == nil { + t.Fatalf("error decoding app ID") + } + if id.String() != string(app.Marshal()) { + t.Errorf("expected %s, got %s", string(app.Marshal()), id.String()) + } + newID := new(AppID).SetString(string(app.Marshal())) + if newID == nil { + t.Fatalf("error decoding app ID") + } + if newID.String() != id.String() { + t.Errorf("expected %s, got %s", id.String(), newID.String()) + } +} + +func TestBytesSetBytesAppID(t *testing.T) { + app := &App{ + Name: testAppName, + RedirectURI: testRedirectURI, + SessionDuration: testSessionDuration, + } + servicePart := []byte("service-secret") + appPart := []byte("app-secret") + secret := new(Secret).SetParts(servicePart, appPart) + app.SetSecret(secret) + id := app.ID(secret) + if id == nil { + t.Fatalf("error decoding app ID") + } + if !bytes.Equal(id.Bytes(), app.Marshal()) { + t.Errorf("expected %v, got %v", app.Marshal(), id.Bytes()) + } + newID := new(AppID).SetBytes(app.Marshal()) + if newID == nil { + t.Fatalf("error decoding app ID") + } + if !bytes.Equal(newID.Bytes(), id.Bytes()) { + t.Errorf("expected %v, got %v", id.Bytes(), newID.Bytes()) + } + var nilID *AppID + nilID = nilID.SetBytes(app.Marshal()) + if nilID == nil { + t.Fatalf("error decoding app ID") + } + if !bytes.Equal(nilID.Bytes(), id.Bytes()) { + t.Errorf("expected %v, got %v", id.Bytes(), nilID.Bytes()) + } + // nil app ID + if nilID = new(AppID).SetBytes(nil); nilID != nil { + t.Errorf("expected nil, got %v", nilID) + } + var nilAppID *AppID + if bNilAppID := nilAppID.Bytes(); bNilAppID != nil { + t.Errorf("expected nil, got %v", bNilAppID) + } +} + +func TestPrivKeySignVerifyAppID(t *testing.T) { + var nilAppID *AppID + if privKey := nilAppID.PrivKey(*testAppSecret); privKey != nil { + t.Errorf("expected nil, got %v", privKey) + } + if nilSig := nilAppID.Sign(*testAppSecret, []byte("test data")); nilSig != nil { + t.Errorf("expected nil, got %v", nilSig) + } + if nilVerify := nilAppID.Verify(*testAppSecret, []byte("test data"), []byte("test sig")); nilVerify { + t.Errorf("expected signature to be invalid") + } + badSecret := new(Secret) + if nilPrivKey := new(AppID).PrivKey(*badSecret); nilPrivKey != nil { + t.Errorf("expected nil, got %v", nilPrivKey) + } + if privKey := new(AppID).PrivKey(*testAppSecret); privKey != nil { + t.Errorf("expected nil, got %v", privKey) + } + if sig := new(AppID).Sign(*testAppSecret, []byte("test data")); sig != nil { + t.Errorf("expected nil, got %v", sig) + } + if new(AppID).Verify(*testAppSecret, []byte("test data"), []byte("test sig")) { + t.Errorf("expected signature to be invalid") + } + app := &App{ + Name: testAppName, + RedirectURI: testRedirectURI, + SessionDuration: testSessionDuration, + } + servicePart := []byte("service-secret") + appPart := []byte("app-secret") + secret := new(Secret).SetParts(servicePart, appPart) + app.SetSecret(secret) + id := app.ID(secret) + if id == nil { + t.Fatalf("error decoding app ID") + } + data := []byte("test data") + sig := id.Sign(*testAppSecret, data) + if sig == nil { + t.Fatalf("error signing data") + } + if !id.Verify(*testAppSecret, data, sig) { + t.Errorf("expected signature to be valid") + } + if id.Verify(*testAppSecret, data, []byte("invalid sig")) { + t.Errorf("expected signature to be invalid") + } +} + +func TestMessage(t *testing.T) { + if privKey := new(AppID).PrivKey(*testAppSecret); privKey != nil { + t.Errorf("expected nil, got %v", privKey) + } + if sig := new(AppID).Sign(*testAppSecret, []byte("test data")); sig != nil { + t.Errorf("expected nil, got %v", sig) + } +} + +func TestGenerateTokenVerifyToken(t *testing.T) { + t.Parallel() + var nilAppID *AppID + if res := nilAppID.GenerateToken(nil, ""); res != nil { + t.Errorf("expected nil, got %v", res) + } + if nilAppID.VerifyToken(nil, *testAppSecret, "") { + t.Errorf("expected token to be invalid") + } + if res := new(AppID).GenerateToken(nil, ""); res != nil { + t.Errorf("expected nil, got %v", res) + } + app := &App{ + Name: testAppName, + RedirectURI: testRedirectURI, + SessionDuration: minDuration, + } + servicePart := []byte("service-secret") + appPart := []byte("app-secret") + secret := new(Secret).SetParts(servicePart, appPart) + app.SetSecret(secret) + id := app.ID(secret) + if id == nil { + t.Fatalf("error decoding app ID") + } + email := "test@email.com" + if id.VerifyToken([]byte{}, *testAppSecret, email) { + t.Errorf("expected token to be invalid") + } + dummyToken := new(Token).SetExpiration(*new(Expiration).SetDuration(minDuration * 3)) + if id.VerifyToken(*dummyToken, *testAppSecret, email) { + t.Errorf("expected token to be invalid") + } + dummyToken = new(Token).SetSignature([]byte("test")) + if id.VerifyToken(*dummyToken, *testAppSecret, email) { + t.Errorf("expected token to be invalid") + } + if invalidToken := id.GenerateToken(*testAppSecret, ""); invalidToken != nil { + t.Fatalf("expected nil, got %v", invalidToken) + } + + token := id.GenerateToken(*testAppSecret, email) + if token == nil { + t.Fatalf("error creating token") + } + if id.VerifyToken(token, *testAppSecret, "") { + t.Errorf("expected token to be invalid") + } + if !id.VerifyToken(token, *testAppSecret, email) { + t.Errorf("expected token to be valid") + } + time.Sleep(app.SessionDuration + time.Second) + if id.VerifyToken(token, *testAppSecret, email) { + t.Errorf("expected token to be invalid") + } + if id.VerifyToken(nil, *testAppSecret, email) { + t.Errorf("expected token to be invalid") + } + exp := new(Expiration).SetDuration(minDuration) + if id.VerifyToken(exp.Marshal(), *testAppSecret, email) { + t.Errorf("expected token to be invalid") + } +} + +func Test_hedgedNonce(t *testing.T) { + if res := hedgedNonce(); res != nil { + t.Errorf("expected nil, got %v", res) + } + if res := hedgedNonce(nil); res != nil { + t.Errorf("expected nil, got %v", res) + } + seed := []byte("test") + hFn := hmac.New(sha256.New, seed) + expected := hFn.Sum(nil) + if res := hedgedNonce(seed); !bytes.Equal(res, expected) { + t.Errorf("expected %v, got %v", expected, res) + } + + seed = []byte("test") + hFn = hmac.New(sha256.New, seed) + in := []byte("data") + hFn.Write(in) + expected = hFn.Sum(nil) + if res := hedgedNonce(seed, in); !bytes.Equal(res, expected) { + t.Errorf("expected %v, got %v", expected, res) + } +} diff --git a/token/secret.go b/token/secret.go new file mode 100644 index 0000000..9b54921 --- /dev/null +++ b/token/secret.go @@ -0,0 +1,70 @@ +package token + +import "crypto/sha256" + +// secretHashSize is the size of the secret hash. It is used to determine +// the size of the secret when it is hashed. The hash is created by hashing +// the secret to a sha256 size. The hash is used to sign and verify tokens. +// The hash is also used to create the app ID and it is part of it. +const secretHashSize = 12 + +// Secret represents a secret that is used to sign and verify tokens. It is +// a wrapper around a byte slice that provides additional methods for setting +// and getting the secret. It should have at least 2 parts, each hashed to a +// sha256 size. +type Secret []byte + +// SetParts method sets the secret's parts from a slice of byte slices. If the +// secret is nil, a new secret is created. If the parts are empty, the secret +// is not set. The parts are hashed to a sha256 size and concatenated to form +// the secret. +func (s *Secret) SetParts(raw ...[]byte) *Secret { + // if no secret is provided, initialize a new one + if s == nil { + s = new(Secret) + } + // hash each part to a sha256 size and concatenate them + newParts := []byte{} + for _, part := range raw { + if len(part) == 0 { + continue + } + hsecret := sha256.Sum256(part) + newParts = append(newParts, hsecret[:]...) + } + // if there are new parts, append them to the secret + if len(newParts) != 0 { + *s = append(*s, newParts...) + } + return s +} + +// Bytes method returns the secret as a byte slice. +func (s *Secret) Bytes() []byte { + return []byte(*s) +} + +// Hash method returns the hash of the secret as a byte slice. The hash is +// created by hashing the secret to a sha256 size. The hash is used to create +// the app ID and it is part of it, but is also used to sign and verify the +// user sessions in the token generation process. +func (s *Secret) Hash() []byte { + if s == nil { + return nil + } + // hash the secret to a sha256 size + h := sha256.Sum256(*s) + return h[:secretHashSize] +} + +// Valid method returns true if the secret is valid, false otherwise. A secret +// is considered valid if it has more than 1 part, and each part is hashed to +// a sha256 size. +func (s *Secret) Valid() bool { + if s == nil { + return false + } + // secret is valid if it has more than 1 part, and each part is hashed + // to a sha256 size + return len(*s)%sha256.Size == 0 && len(*s) > sha256.Size +} diff --git a/token/secret_test.go b/token/secret_test.go new file mode 100644 index 0000000..981819f --- /dev/null +++ b/token/secret_test.go @@ -0,0 +1,76 @@ +package token + +import ( + "bytes" + "crypto/sha256" + "testing" +) + +func TestSetPartsSecret(t *testing.T) { + servicePart := []byte("service-secret") + appPart := []byte("app-secret") + valid := new(Secret).SetParts(servicePart, appPart) + if valid == nil { + t.Errorf("expected Secret, got nil") + } + hServicePart := sha256.Sum256(servicePart) + hAppPart := sha256.Sum256(appPart) + expected := append(hServicePart[:], hAppPart[:]...) + if !bytes.Equal(valid.Bytes(), expected) { + t.Errorf("expected %x, got %x", expected[:], valid.Bytes()) + } + if noServicePart := new(Secret).SetParts(nil, appPart); !bytes.Equal(noServicePart.Bytes(), hAppPart[:]) { + t.Errorf("expected nil, got %x", noServicePart) + } + if noAppPart := new(Secret).SetParts(servicePart, nil); !bytes.Equal(noAppPart.Bytes(), hServicePart[:]) { + t.Errorf("expected nil, got %x", noAppPart) + } + var nilSecret *Secret + valid = nilSecret.SetParts(servicePart, appPart) + if !bytes.Equal(valid.Bytes(), expected) { + t.Errorf("expected nil, got %v", nilSecret) + } +} + +func TestValidSecret(t *testing.T) { + var nilSecret *Secret + if valid := nilSecret.Valid(); valid { + t.Errorf("expected false, got %v", valid) + } + if valid := new(Secret).Valid(); valid { + t.Errorf("expected false, got %v", valid) + } + servicePart := []byte("service-secret") + singlePart := new(Secret).SetParts(servicePart) + if valid := singlePart.Valid(); valid { + t.Errorf("expected false, got %v", valid) + } + appPart := []byte("app-secret") + valid := new(Secret).SetParts(servicePart, appPart) + if !valid.Valid() { + t.Errorf("expected true, got false") + } +} + +func TestSecretHash(t *testing.T) { + servicePart := []byte("service-secret") + appPart := []byte("app-secret") + secret := new(Secret).SetParts(servicePart, appPart) + h := secret.Hash() + if h == nil { + t.Errorf("expected hash, got nil") + } + if len(h) != secretHashSize { + t.Errorf("expected hash size %d, got %d", secretHashSize, len(h)) + } + hSecret := sha256.Sum256(secret.Bytes()) + if !bytes.Equal(h, hSecret[:secretHashSize]) { + t.Errorf("expected %x, got %x", hSecret[:secretHashSize], h) + } + // try to hash a nil secret + var nilSecret *Secret + h = nilSecret.Hash() + if h != nil { + t.Errorf("expected nil, got %x", h) + } +} diff --git a/token/token.go b/token/token.go new file mode 100644 index 0000000..ce90604 --- /dev/null +++ b/token/token.go @@ -0,0 +1,155 @@ +package token + +import "bytes" + +// Token is a type that represents a user token. It is a wrapper around a byte +// slice that provides additional methods for setting and getting the token. It +// should have 2 parts, the first part is the expiration time, and the second +// part is the signature. +type Token []byte + +// String method returns the token as a string. It is useful for encoding the +// token. If the token is nil, an empty string is returned. It internally calls +// the Bytes method to get the token as a byte slice. +func (t *Token) String() string { + if t == nil { + return "" + } + return string(t.Bytes()) +} + +// SetString method sets the token from a string. It is useful for decoding the +// token. The string should be the token's expiration time and signature joined +// by the token separator. If the token is invalid, the token is not set. It +// internally calls the SetBytes method to set the token from a byte slice. +func (t *Token) SetString(data string) *Token { + if t == nil { + t = new(Token) + } + return t.SetBytes([]byte(data)) +} + +// Bytes method returns the token as a byte slice. It is useful for encoding +// the token. If the token is nil, nil is returned. It internally calls the +// parts method to get the token's expiration time and signature as byte +// slices. It checks that the parts are valid before returning the token +// as a byte slice. +func (t *Token) Bytes() []byte { + // check if the token is nil + if t == nil { + return nil + } + // check if the token has valid parts + if _, _, ok := t.parts(); !ok { + return nil + } + // return the token as a byte slice + return []byte(*t) +} + +// SetBytes method sets the token from a byte slice. It is useful for +// decoding the token. The byte slice should be the token's expiration time +// and signature joined by the token separator. If the token is invalid, the +// token is not set. It internally calls the parts method to get the token's +// expiration time and signature as byte slices. It checks that the parts are +// valid before setting the token from the byte slice. +func (t *Token) SetBytes(data []byte) *Token { + // if no token is provided, create a new one + if t == nil { + t = new(Token) + } + // generate a new token from the data + ntoken := &Token{} + *ntoken = data + // check if the new token has valid parts + if _, _, ok := ntoken.parts(); !ok { + return t + } + // set the token to the new token + *t = data + return t +} + +// Expiration method returns the token's expiration time. It is useful for +// getting the expiration time. If the token is nil, nil is returned. It +// internally calls the parts method to get the token's expiration time and +// signature as byte slices. It checks that the expiration time is valid before +// returning it. +func (t *Token) Expiration() *Expiration { + if rawExp, _, ok := t.parts(); ok { + return new(Expiration).Unmarshal(rawExp) + } + return nil +} + +// SetExpiration method sets the token's expiration time. If the token is nil, +// a new token is created. If the expiration time is invalid, the token is not +// set. It internally calls the parts method to replace the token's expiration +// time with the new expiration time. It checks that the expiration time is +// valid before setting it. +func (t *Token) SetExpiration(exp Expiration) *Token { + // if no token is provided, create a new one + if t == nil { + t = new(Token) + } + // if expiration is invalid, return the current token + if !exp.Valid() { + return t + } + // create a base content with the new expiration and no signature + baseContent := append(exp.Marshal(), tokenSeparator) + // get the current signature, if there is one, return a token with the + // base content and no signature + sig := t.Signature() + if sig == nil { + return t.SetBytes(baseContent) + } + // if there is a signature, update the token to replace the expiration + // part with the new expiration + return t.SetBytes(append(baseContent, sig...)) +} + +// Signature method returns the token's signature. If the token is nil, nil is +// returned. It internally calls the parts method to get the signature part as +// byte slices. +func (t *Token) Signature() []byte { + _, sig, _ := t.parts() + return sig +} + +// SetSignature method sets the token's signature. If the token is nil, a +// new token is created. If the signature is nil, the token is not set. It +// internally calls the parts method to replace the token's signature with +// the new signature. +func (t *Token) SetSignature(sig []byte) *Token { + // if no token is provided, create a new one + if t == nil { + t = new(Token) + } + // create a base content with no expiration and the new signature + baseContent := append([]byte{tokenSeparator}, sig...) + // get the current expiration, if there is none, return a token with the + // base content and no expiration + exp := t.Expiration() + if exp == nil { + return t.SetBytes(baseContent) + } + // if there is an expiration, update the token to replace the signature + // part with the new signature + return t.SetBytes(append(exp.Marshal(), baseContent...)) +} + +// parts private method returns the token's expiration time and signature as +// byte slices. It also returns a boolean indicating if the token has valid +// parts. If the token is nil, or the parts are invalid, the parts are nil and +// the boolean is false. It splits the token by the token separator and checks +// that the result has 2 parts (expiration time and signature). +func (t *Token) parts() ([]byte, []byte, bool) { + if t == nil { + return nil, nil, false + } + if p := bytes.Split([]byte(*t), []byte{tokenSeparator}); len(p) == 2 { + return p[0], p[1], true + } + return nil, nil, false +} diff --git a/token/token_test.go b/token/token_test.go new file mode 100644 index 0000000..6b9c80a --- /dev/null +++ b/token/token_test.go @@ -0,0 +1,159 @@ +package token + +import ( + "bytes" + "testing" + "time" +) + +func TestStringSetStringToken(t *testing.T) { + var token *Token + token.SetString("test") + if token.String() != "" { + t.Errorf("expected empty string, got %s", token.String()) + } + token = token.SetString("test") + if token.String() != "" { + t.Errorf("expected empty string, got %s", token.String()) + } + exp := new(Expiration).SetDuration(minDuration * 2) + expected := string(exp.Marshal()) + string(tokenSeparator) + token.SetString(expected) + if token.String() != expected { + t.Errorf("expected %s, got %s", expected, token.String()) + } + + expected = string(tokenSeparator) + "testSignature" + token.SetString(expected) + if token.String() != expected { + t.Errorf("expected %s, got %s", expected, token.String()) + } + + expected = string(exp.Marshal()) + string(tokenSeparator) + "testSignature" + token.SetString(expected) + if token.String() != expected { + t.Errorf("expected %s, got %s", expected, token.String()) + } +} + +func TestBytesSetBytesToken(t *testing.T) { + var token *Token + token.SetBytes([]byte("test")) + if token.Bytes() != nil { + t.Errorf("expected nil, got %v", token.Bytes()) + } + token = token.SetBytes([]byte("test")) + if token.Bytes() != nil { + t.Errorf("expected nil, got %v", token.Bytes()) + } + exp := new(Expiration).SetDuration(minDuration * 2) + onlyExp := append(exp.Marshal(), tokenSeparator) + token.SetBytes(onlyExp) + if !bytes.Equal(token.Bytes(), onlyExp) { + t.Errorf("expected %v, got %v", onlyExp, token.Bytes()) + } + + onlySign := append([]byte{tokenSeparator}, []byte("testSignature")...) + token.SetBytes(onlySign) + if !bytes.Equal(token.Bytes(), onlySign) { + t.Errorf("expected %v, got %v", onlySign, token.Bytes()) + } + + fullToken := append(append(exp.Marshal(), tokenSeparator), []byte("testSignature")...) + token.SetBytes(fullToken) + if !bytes.Equal(token.Bytes(), fullToken) { + t.Errorf("expected %v, got %v", fullToken, token.Bytes()) + } +} + +func TestExpirationSetExpirationToken(t *testing.T) { + exp := new(Expiration).SetDuration(minDuration * 2) + var token *Token + token.SetExpiration(*exp) + if token.String() != "" { + t.Errorf("expected empty string, got %s", token.String()) + } + token = token.SetExpiration(Expiration(time.Now().Add(-time.Second))) + if token.String() != "" { + t.Errorf("expected empty string, got %s", token.String()) + } + token.SetExpiration(*exp) + justExpExp := token.Expiration() + if justExpExp == nil { + t.Fatalf("expected valid expiration, got nil") + } + if justExpExp.String() != exp.String() { + t.Errorf("expected %v, got %v", exp, justExpExp) + } + token.SetSignature([]byte("test")) + completeTokenExp := token.Expiration() + if completeTokenExp == nil { + t.Fatalf("expected valid expiration, got nil") + } + if completeTokenExp.String() != exp.String() { + t.Errorf("expected %v, got %v", exp, completeTokenExp) + } + exp = new(Expiration).SetDuration(minDuration * 3) + token.SetExpiration(*exp) + newExpExp := token.Expiration() + if newExpExp == nil || newExpExp.String() != exp.String() { + t.Errorf("expected %v, got %v", exp, newExpExp) + } + validStr := token.String() + newtoken := new(Token).SetString(validStr) + newtokenExp := newtoken.Expiration() + if newtokenExp == nil || newtokenExp.String() != exp.String() { + t.Errorf("expected %v, got %v", exp, newtokenExp) + } +} + +func TestSignatureSetSignatureToken(t *testing.T) { + var token *Token + token.SetSignature([]byte("test")) + if token.String() != "" { + t.Errorf("expected empty string, got %s", token.String()) + } + token = token.SetSignature([]byte("test")) + justSignSign := token.Signature() + if !bytes.Equal(justSignSign, []byte("test")) { + t.Errorf("expected %v, got %v", []byte("test"), justSignSign) + } + token.SetExpiration(*new(Expiration).SetDuration(minDuration * 3)) + completeTokenSign := token.Signature() + if !bytes.Equal(completeTokenSign, []byte("test")) { + t.Errorf("expected %v, got %v", []byte("test"), completeTokenSign) + } + validStr := token.String() + newtoken := new(Token).SetString(validStr) + newtokenSign := newtoken.Signature() + if !bytes.Equal(newtokenSign, []byte("test")) { + t.Errorf("expected %v, got %v", []byte("test"), newtokenSign) + } +} + +func Test_partsToken(t *testing.T) { + var token *Token + if _, _, ok := token.parts(); ok { + t.Errorf("expected false, got true") + } + exp := new(Expiration).SetDuration(minDuration * 2) + token = token.SetExpiration(*exp) + rawExp, _, ok := token.parts() + if !ok { + t.Errorf("expected false, got true") + } + if !bytes.Equal(rawExp, exp.Marshal()) { + t.Errorf("expected %v, got %v", exp.Marshal(), rawExp) + } + token.SetSignature([]byte("test")) + rawExp, sign, ok := token.parts() + if !ok { + t.Errorf("expected true, got false") + } + if !bytes.Equal(rawExp, exp.Marshal()) { + t.Errorf("expected %v, got %v", exp.Marshal(), rawExp) + } + if !bytes.Equal(sign, []byte("test")) { + t.Errorf("expected %v, got %v", []byte("test"), sign) + } +}